Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #408: Fit the susceptible population #904

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
estimate_infections()
```

- Added support for fitting the susceptible population size. By @seabbs in #904 and reviewed by @sbfnk.
- A bug was fixed where the initial growth was never estimated (i.e. the prior mean was always zero). By @sbfnk in #853 and reviewed by @seabbs.
- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and reviewed by @seabbs.
Expand Down
6 changes: 4 additions & 2 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL,
breakpoints = breakpoints,
future_fixed = as.numeric(future_rt$fixed),
fixed_from = future_rt$from,
pop = rt$pop,
use_pop = as.integer(rt$pop != Fixed(0)) + as.integer(rt$estimate_pop),
stationary = as.numeric(rt$gp_on == "R0"),
future_time = horizon - future_rt$from
)
Expand Down Expand Up @@ -567,11 +567,13 @@ create_stan_data <- function(data, seeding_time,
R0 = rt$prior,
frac_obs = obs$scale,
rep_phi = obs$phi,
pop = rt$pop,
lower_bounds = c(
alpha = 0,
R0 = 0,
frac_obs = 0,
rep_phi = 0
rep_phi = 0,
pop = 0
)
)
)
Expand Down
40 changes: 34 additions & 6 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,17 @@ trunc_opts <- function(dist = Fixed(0), default_cdf_cutoff = 0.001,
#' conservative estimate of break point changes (alter this by setting
#' `gp = NULL`).
#'
#' @param pop Integer, defaults to 0. Susceptible population initially present.
#' Used to adjust Rt estimates when otherwise fixed based on the proportion of
#' the population that is susceptible. When set to 0 no population adjustment
#' @param pop A `<dist_spec>` giving the initial susceptible population size.
#' Used to adjust Rt estimates based on the proportion of the population that
#' is susceptible. Defaults to `Fixed(0)` which means no population adjustment
#' is done.
#'
#' @param pop_within_horizon Logical, defaults to `FALSE`. Should the
#' susceptible population adjustment be applied within the time horizon of data
#' or just beyond it. Note that if `pop_within_horizon = TRUE` the Rt estimate
#' will be unadjusted for susceptible depletion but the resulting posterior
#' predictions for infections and cases will be adjusted for susceptible
#' depletion.
#'
#' @param gp_on Character string, defaulting to "R_t-1". Indicates how the
#' Gaussian process, if in use, should be applied to Rt. Currently supported
Expand Down Expand Up @@ -354,13 +361,13 @@ rt_opts <- function(prior = LogNormal(mean = 1, sd = 1),
use_breakpoints = TRUE,
future = "latest",
gp_on = c("R_t-1", "R0"),
pop = 0) {
pop = Fixed(0),
pop_within_horizon = FALSE) {
rt <- list(
use_rt = use_rt,
rw = rw,
use_breakpoints = use_breakpoints,
future = future,
pop = pop,
gp_on = arg_match(gp_on)
)

Expand Down Expand Up @@ -388,6 +395,24 @@ rt_opts <- function(prior = LogNormal(mean = 1, sd = 1),
prior <- LogNormal(mean = prior$mean, sd = prior$sd)
}

if (is.numeric(pop)) {
lifecycle::deprecate_warn(
"1.7.0",
"rt_opts(pop = 'must be a `<dist_spec>`')",
details = "For specifying a fixed population size, use `Fixed(pop)`"
)
pop <- Fixed(pop)
}
rt$pop <- pop
if (isTRUE(pop_within_horizon) && pop == Fixed(0)) {
cli_abort(
c(
"!" = "pop_within_horizon = TRUE but pop is fixed at 0."
)
)
}
rt$estimate_pop <- TRUE

if (rt$use_rt) {
rt$prior <- prior
} else {
Expand Down Expand Up @@ -698,7 +723,10 @@ obs_opts <- function(family = c("negbin", "poisson"),
cli_abort(
c(
"!" = "Specifying {.var phi} as a vector of length 2 is deprecated.",
"i" = "Mean and SD should be given as list elements."
"i" = paste0(
"Use a {.cls dist_spec} instead, e.g. Normal(mean = {phi[1]}, ",
"sd = {phi[2]})."
)
)
)
}
Expand Down
9 changes: 5 additions & 4 deletions R/simulate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ simulate_infections <- function(estimates, R, initial_infections,
CrIs = c(0.2, 0.5, 0.9),
backend = "rstan",
seeding_time = NULL,
pop = 0, ...) {
pop = Fixed(0), ...) {
## deprecated usage
if (!missing(estimates)) {
deprecate_stop(
Expand All @@ -86,14 +86,14 @@ simulate_infections <- function(estimates, R, initial_infections,
assert_numeric(R$R, lower = 0)
assert_numeric(initial_infections, lower = 0)
assert_numeric(day_of_week_effect, lower = 0, null.ok = TRUE)
assert_numeric(pop, lower = 0)
if (!is.null(seeding_time)) {
assert_integerish(seeding_time, lower = 1)
}
assert_class(delays, "delay_opts")
assert_class(truncation, "trunc_opts")
assert_class(obs, "obs_opts")
assert_class(generation_time, "generation_time_opts")
assert_class(pop, "dist_spec")

## create R for all dates modelled
all_dates <- data.table(date = seq.Date(min(R$date), max(R$date), by = "day"))
Expand Down Expand Up @@ -125,7 +125,7 @@ simulate_infections <- function(estimates, R, initial_infections,
initial_infections = array(log_initial_infections, dim = c(1, 1)),
initial_growth = array(initial_growth, dim = c(1, length(initial_growth))),
R = array(R$R, dim = c(1, nrow(R))),
pop = pop
use_pop = as.integer(pop != Fixed(0))
)

data <- c(data, create_stan_delays(
Expand Down Expand Up @@ -178,7 +178,8 @@ simulate_infections <- function(estimates, R, initial_infections,
alpha = NULL,
R0 = NULL,
frac_obs = obs$scale,
rep_phi = obs$phi
rep_phi = obs$phi,
pop = pop
))
## set empty params matrix - variable parameters not supported here
data$params <- array(dim = c(1, 0))
Expand Down
1 change: 1 addition & 0 deletions inst/stan/data/estimate_infections_params.stan
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ int<lower = 0> alpha_id; // parameter id of alpha (GP magnitude)
int<lower = 0> R0_id; // parameter id of R0
int<lower = 0> frac_obs_id; // parameter id of frac_obs
int<lower = 0> rep_phi_id; // parameter id of rep_phi_id
int<lower = 0> pop_id; // parameter id of pop
2 changes: 1 addition & 1 deletion inst/stan/data/rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
array[t - seeding_time] int breakpoints; // when do breakpoints occur
int future_fixed; // is underlying future Rt assumed to be fixed
int fixed_from; // Reference date for when Rt estimation should be fixed
int pop; // Initial susceptible population
int use_pop; // use population size
int<lower = 0> gt_id; // id of generation time
2 changes: 1 addition & 1 deletion inst/stan/data/simulation_rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
array[n, seeding_time > 1 ? 1 : 0] real initial_growth; //initial growth

matrix[n, t - seeding_time] R; // reproduction number
int pop; // susceptible population
int use_pop; // use population size

int<lower = 0> gt_id; // id of generation time
6 changes: 5 additions & 1 deletion inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,13 @@ transformed parameters {
);
}
profile("infections") {
real pop = get_param(
pop_id, params_fixed_lookup, params_variable_lookup, params_value,
params
);
infections = generate_infections(
R, seeding_time, gt_rev_pmf, initial_infections, initial_growth, pop,
future_time
use_pop, future_time
);
}
} else {
Expand Down
8 changes: 4 additions & 4 deletions inst/stan/functions/infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ real update_infectiousness(vector infections, vector gt_rev_pmf,
// generate infections by using Rt = Rt-1 * sum(reversed generation time pmf * infections)
vector generate_infections(vector oR, int uot, vector gt_rev_pmf,
array[] real initial_infections, array[] real initial_growth,
int pop, int ht) {
real pop, int use_pop, int ht) {
// time indices and storage
int ot = num_elements(oR);
int nht = ot - ht;
Expand All @@ -39,20 +39,20 @@ vector generate_infections(vector oR, int uot, vector gt_rev_pmf,
}
}
// calculate cumulative infections
if (pop) {
if (use_pop) {
cum_infections[1] = sum(infections[1:uot]);
}
// iteratively update infections
for (s in 1:ot) {
infectiousness[s] = update_infectiousness(infections, gt_rev_pmf, uot, s);
if (pop && s > nht) {
if (use_pop == 2 && s > nht) {
exp_adj_Rt = exp(-R[s] * infectiousness[s] / (pop - cum_infections[nht]));
exp_adj_Rt = exp_adj_Rt > 1 ? 1 : exp_adj_Rt;
infections[s + uot] = (pop - cum_infections[s]) * (1 - exp_adj_Rt);
}else{
infections[s + uot] = R[s] * infectiousness[s];
}
if (pop && s < ot) {
if (use_pop && s < ot) {
cum_infections[s + 1] = cum_infections[s] + infections[s + uot];
}
}
Expand Down
8 changes: 7 additions & 1 deletion inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ generated quantities {
frac_obs_id, params_fixed_lookup, params_variable_lookup,
params_value, params
);

vector[n] pop = get_param(
pop_id, params_fixed_lookup, params_variable_lookup,
params_value, params
);

for (i in 1:n) {
// generate infections from Rt trace
vector[delay_type_max[gt_id] + 1] gt_rev_pmf;
Expand All @@ -62,7 +68,7 @@ generated quantities {

infections[i] = to_row_vector(generate_infections(
to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i],
initial_growth[i], pop, future_time
initial_growth[i], pop[i], use_pop, future_time
));

if (delay_id) {
Expand Down
2 changes: 1 addition & 1 deletion man/EpiNow2-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 12 additions & 4 deletions man/rt_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/simulate_infections.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions tests/testthat/test-create_rt_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ test_that("create_rt_data returns expected default values", {
expect_equal(result$breakpoints, numeric(0))
expect_equal(result$future_fixed, 1)
expect_equal(result$fixed_from, 0)
expect_equal(result$pop, 0)
expect_equal(result$use_pop, 0)
expect_equal(result$stationary, 0)
expect_equal(result$future_time, 0)
})
Expand All @@ -27,13 +27,13 @@ test_that("create_rt_data handles custom rt_opts correctly", {
use_breakpoints = FALSE,
future = "project",
gp_on = "R0",
pop = 1000000
pop = Normal(mean = 1000000, sd = 100)
)

result <- create_rt_data(rt = custom_rt, horizon = 7)

expect_equal(result$estimate_r, 0)
expect_equal(result$pop, 1000000)
expect_equal(result$use_pop, 1)
expect_equal(result$stationary, 1)
expect_equal(result$future_time, 7)
})
Expand Down
16 changes: 11 additions & 5 deletions tests/testthat/test-rt_opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ test_that("rt_opts returns expected default values", {
expect_equal(result$rw, 0)
expect_true(result$use_breakpoints)
expect_equal(result$future, "latest")
expect_equal(result$pop, 0)
expect_equal(result$pop, Fixed(0))
expect_equal(result$gp_on, "R_t-1")
})

Expand All @@ -19,18 +19,25 @@ test_that("rt_opts handles custom inputs correctly", {
use_breakpoints = FALSE,
future = "project",
gp_on = "R0",
pop = 1000000
pop = Normal(mean = 1000000, sd = 100)
))

expect_null(result$prior)
expect_false(result$use_rt)
expect_equal(result$rw, 7)
expect_true(result$use_breakpoints) # Should be TRUE when rw > 0
expect_equal(result$future, "project")
expect_equal(result$pop, 1000000)
expect_equal(result$pop, Normal(mean = 1000000, sd = 100))
expect_equal(result$gp_on, "R0")
})

test_that("rt_opts warns when pop is passed as numeric", {
expect_warning(
rt_opts(pop = 1000),
"Specifying `pop` as a numeric value is deprecated"
)
})

test_that("rt_opts sets use_breakpoints to TRUE when rw > 0", {
result <- rt_opts(rw = 3, use_breakpoints = FALSE)
expect_true(result$use_breakpoints)
Expand Down Expand Up @@ -59,8 +66,7 @@ test_that("rt_opts returns object of correct class", {
})

test_that("rt_opts handles edge cases correctly", {
result <- rt_opts(rw = 0.1, pop = -1)
result <- rt_opts(rw = 0.1)
expect_equal(result$rw, 0.1)
expect_equal(result$pop, -1)
expect_true(result$use_breakpoints)
})
Loading
Loading