From 6d221589376580eb472d65a3f15b261e0230f459 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 19 Dec 2024 17:43:43 +0000 Subject: [PATCH] add support to stan --- R/create.R | 6 ++++-- inst/stan/data/estimate_infections_params.stan | 1 + inst/stan/data/rt.stan | 4 ++-- inst/stan/estimate_infections.stan | 6 +++++- inst/stan/functions/infections.stan | 8 ++++---- inst/stan/simulate_infections.stan | 8 +++++++- 6 files changed, 23 insertions(+), 10 deletions(-) diff --git a/R/create.R b/R/create.R index 88da040c8..9c8f9b818 100644 --- a/R/create.R +++ b/R/create.R @@ -280,7 +280,7 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, breakpoints = breakpoints, future_fixed = as.numeric(future_rt$fixed), fixed_from = future_rt$from, - pop = as.integer(rt$pop != Fixed(0)), + use_pop = as.integer(rt$pop != Fixed(0)), stationary = as.numeric(rt$gp_on == "R0"), future_time = horizon - future_rt$from ) @@ -567,11 +567,13 @@ create_stan_data <- function(data, seeding_time, R0 = rt$prior, frac_obs = obs$scale, rep_phi = obs$phi, + pop = rt$pop, lower_bounds = c( alpha = 0, R0 = 0, frac_obs = 0, - rep_phi = 0 + rep_phi = 0, + pop = 0 ) ) ) diff --git a/inst/stan/data/estimate_infections_params.stan b/inst/stan/data/estimate_infections_params.stan index 3351f5ea3..5db9976f0 100644 --- a/inst/stan/data/estimate_infections_params.stan +++ b/inst/stan/data/estimate_infections_params.stan @@ -2,3 +2,4 @@ int alpha_id; // parameter id of alpha (GP magnitude) int R0_id; // parameter id of R0 int frac_obs_id; // parameter id of frac_obs int rep_phi_id; // parameter id of rep_phi_id +int pop_id; // parameter id of pop diff --git a/inst/stan/data/rt.stan b/inst/stan/data/rt.stan index b736f1ade..07dadc036 100644 --- a/inst/stan/data/rt.stan +++ b/inst/stan/data/rt.stan @@ -5,5 +5,5 @@ array[t - seeding_time] int breakpoints; // when do breakpoints occur int future_fixed; // is underlying future Rt assumed to be fixed int fixed_from; // Reference date for when Rt estimation should be fixed - int pop; // Initial susceptible population - int gt_id; // id of generation time + int use_pop; // use population size + int gt_id; // id of generation time \ No newline at end of file diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index 8202c962c..5b8b23329 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -95,9 +95,13 @@ transformed parameters { ); } profile("infections") { + real pop = get_param( + pop_id, params_fixed_lookup, params_variable_lookup, params_value, + params + ); infections = generate_infections( R, seeding_time, gt_rev_pmf, initial_infections, initial_growth, pop, - future_time + use_pop, future_time ); } } else { diff --git a/inst/stan/functions/infections.stan b/inst/stan/functions/infections.stan index b7790c582..ed0d3bf5a 100644 --- a/inst/stan/functions/infections.stan +++ b/inst/stan/functions/infections.stan @@ -20,7 +20,7 @@ real update_infectiousness(vector infections, vector gt_rev_pmf, // generate infections by using Rt = Rt-1 * sum(reversed generation time pmf * infections) vector generate_infections(vector oR, int uot, vector gt_rev_pmf, array[] real initial_infections, array[] real initial_growth, - int pop, int ht) { + real pop, int use_pop, int ht) { // time indices and storage int ot = num_elements(oR); int nht = ot - ht; @@ -39,20 +39,20 @@ vector generate_infections(vector oR, int uot, vector gt_rev_pmf, } } // calculate cumulative infections - if (pop) { + if (use_pop) { cum_infections[1] = sum(infections[1:uot]); } // iteratively update infections for (s in 1:ot) { infectiousness[s] = update_infectiousness(infections, gt_rev_pmf, uot, s); - if (pop && s > nht) { + if (use_pop && s > nht) { exp_adj_Rt = exp(-R[s] * infectiousness[s] / (pop - cum_infections[nht])); exp_adj_Rt = exp_adj_Rt > 1 ? 1 : exp_adj_Rt; infections[s + uot] = (pop - cum_infections[s]) * (1 - exp_adj_Rt); }else{ infections[s + uot] = R[s] * infectiousness[s]; } - if (pop && s < ot) { + if (use_pop && s < ot) { cum_infections[s + 1] = cum_infections[s] + infections[s + uot]; } } diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index 3e8131994..b3cd57985 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -50,6 +50,12 @@ generated quantities { frac_obs_id, params_fixed_lookup, params_variable_lookup, params_value, params ); + + real pop = get_param( + pop_id, params_fixed_lookup, params_variable_lookup, + params_value, params + ); + for (i in 1:n) { // generate infections from Rt trace vector[delay_type_max[gt_id] + 1] gt_rev_pmf; @@ -62,7 +68,7 @@ generated quantities { infections[i] = to_row_vector(generate_infections( to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i], - initial_growth[i], pop, future_time + initial_growth[i], pop, use_pop, future_time )); if (delay_id) {