Skip to content

Commit

Permalink
add support to stan
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Dec 19, 2024
1 parent b5063cb commit 6d22158
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 10 deletions.
6 changes: 4 additions & 2 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL,
breakpoints = breakpoints,
future_fixed = as.numeric(future_rt$fixed),
fixed_from = future_rt$from,
pop = as.integer(rt$pop != Fixed(0)),
use_pop = as.integer(rt$pop != Fixed(0)),
stationary = as.numeric(rt$gp_on == "R0"),
future_time = horizon - future_rt$from
)
Expand Down Expand Up @@ -567,11 +567,13 @@ create_stan_data <- function(data, seeding_time,
R0 = rt$prior,
frac_obs = obs$scale,
rep_phi = obs$phi,
pop = rt$pop,
lower_bounds = c(
alpha = 0,
R0 = 0,
frac_obs = 0,
rep_phi = 0
rep_phi = 0,
pop = 0
)
)
)
Expand Down
1 change: 1 addition & 0 deletions inst/stan/data/estimate_infections_params.stan
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ int<lower = 0> alpha_id; // parameter id of alpha (GP magnitude)
int<lower = 0> R0_id; // parameter id of R0
int<lower = 0> frac_obs_id; // parameter id of frac_obs
int<lower = 0> rep_phi_id; // parameter id of rep_phi_id
int<lower = 0> pop_id; // parameter id of pop
4 changes: 2 additions & 2 deletions inst/stan/data/rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
array[t - seeding_time] int breakpoints; // when do breakpoints occur
int future_fixed; // is underlying future Rt assumed to be fixed
int fixed_from; // Reference date for when Rt estimation should be fixed
int pop; // Initial susceptible population
int<lower = 0> gt_id; // id of generation time
int use_pop; // use population size
int<lower = 0> gt_id; // id of generation time
6 changes: 5 additions & 1 deletion inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,13 @@ transformed parameters {
);
}
profile("infections") {
real pop = get_param(
pop_id, params_fixed_lookup, params_variable_lookup, params_value,
params
);
infections = generate_infections(
R, seeding_time, gt_rev_pmf, initial_infections, initial_growth, pop,
future_time
use_pop, future_time
);
}
} else {
Expand Down
8 changes: 4 additions & 4 deletions inst/stan/functions/infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ real update_infectiousness(vector infections, vector gt_rev_pmf,
// generate infections by using Rt = Rt-1 * sum(reversed generation time pmf * infections)
vector generate_infections(vector oR, int uot, vector gt_rev_pmf,
array[] real initial_infections, array[] real initial_growth,
int pop, int ht) {
real pop, int use_pop, int ht) {
// time indices and storage
int ot = num_elements(oR);
int nht = ot - ht;
Expand All @@ -39,20 +39,20 @@ vector generate_infections(vector oR, int uot, vector gt_rev_pmf,
}
}
// calculate cumulative infections
if (pop) {
if (use_pop) {
cum_infections[1] = sum(infections[1:uot]);
}
// iteratively update infections
for (s in 1:ot) {
infectiousness[s] = update_infectiousness(infections, gt_rev_pmf, uot, s);
if (pop && s > nht) {
if (use_pop && s > nht) {
exp_adj_Rt = exp(-R[s] * infectiousness[s] / (pop - cum_infections[nht]));
exp_adj_Rt = exp_adj_Rt > 1 ? 1 : exp_adj_Rt;
infections[s + uot] = (pop - cum_infections[s]) * (1 - exp_adj_Rt);
}else{
infections[s + uot] = R[s] * infectiousness[s];
}
if (pop && s < ot) {
if (use_pop && s < ot) {
cum_infections[s + 1] = cum_infections[s] + infections[s + uot];
}
}
Expand Down
8 changes: 7 additions & 1 deletion inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ generated quantities {
frac_obs_id, params_fixed_lookup, params_variable_lookup,
params_value, params
);

real pop = get_param(
pop_id, params_fixed_lookup, params_variable_lookup,
params_value, params
);

for (i in 1:n) {
// generate infections from Rt trace
vector[delay_type_max[gt_id] + 1] gt_rev_pmf;
Expand All @@ -62,7 +68,7 @@ generated quantities {

infections[i] = to_row_vector(generate_infections(
to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i],
initial_growth[i], pop, future_time
initial_growth[i], pop, use_pop, future_time
));

if (delay_id) {
Expand Down

0 comments on commit 6d22158

Please sign in to comment.