Skip to content

Commit

Permalink
check works and update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Dec 19, 2024
1 parent 6d22158 commit f93f380
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 22 deletions.
4 changes: 2 additions & 2 deletions R/simulate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ simulate_infections <- function(estimates, R, initial_infections,
CrIs = c(0.2, 0.5, 0.9),
backend = "rstan",
seeding_time = NULL,
pop = 0, ...) {
pop = Fixed(0), ...) {
## deprecated usage
if (!missing(estimates)) {
deprecate_stop(
Expand Down Expand Up @@ -125,7 +125,7 @@ simulate_infections <- function(estimates, R, initial_infections,
initial_infections = array(log_initial_infections, dim = c(1, 1)),
initial_growth = array(initial_growth, dim = c(1, length(initial_growth))),
R = array(R$R, dim = c(1, nrow(R))),
pop = pop
use_pop = as.integer(pop != Fixed(0))
)

data <- c(data, create_stan_delays(
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/data/rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
int future_fixed; // is underlying future Rt assumed to be fixed
int fixed_from; // Reference date for when Rt estimation should be fixed
int use_pop; // use population size
int<lower = 0> gt_id; // id of generation time
int<lower = 0> gt_id; // id of generation time
2 changes: 1 addition & 1 deletion inst/stan/data/simulation_rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
array[n, seeding_time > 1 ? 1 : 0] real initial_growth; //initial growth

matrix[n, t - seeding_time] R; // reproduction number
int pop; // susceptible population
int use_pop; // use population size

int<lower = 0> gt_id; // id of generation time
4 changes: 2 additions & 2 deletions inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ generated quantities {
params_value, params
);

real pop = get_param(
vector[n] pop = get_param(
pop_id, params_fixed_lookup, params_variable_lookup,
params_value, params
);
Expand All @@ -68,7 +68,7 @@ generated quantities {

infections[i] = to_row_vector(generate_infections(
to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i],
initial_growth[i], pop, use_pop, future_time
initial_growth[i], pop[i], use_pop, future_time
));

if (delay_id) {
Expand Down
2 changes: 1 addition & 1 deletion man/EpiNow2-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/rt_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/simulate_infections.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/test-create_rt_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ test_that("create_rt_data handles custom rt_opts correctly", {
use_breakpoints = FALSE,
future = "project",
gp_on = "R0",
pop = 1000000
pop = Normal(mean = 1000000, sd = 100)
)

result <- create_rt_data(rt = custom_rt, horizon = 7)

expect_equal(result$estimate_r, 0)
expect_equal(result$pop, 1000000)
expect_equal(result$use_pop, 1)
expect_equal(result$stationary, 1)
expect_equal(result$future_time, 7)
})
Expand Down
14 changes: 10 additions & 4 deletions tests/testthat/test-rt_opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,25 @@ test_that("rt_opts handles custom inputs correctly", {
use_breakpoints = FALSE,
future = "project",
gp_on = "R0",
pop = 1000000
pop = Normal(mean = 1000000, sd = 100)
))

expect_null(result$prior)
expect_false(result$use_rt)
expect_equal(result$rw, 7)
expect_true(result$use_breakpoints) # Should be TRUE when rw > 0
expect_equal(result$future, "project")
expect_equal(result$pop, 1000000)
expect_equal(result$pop, Normal(mean = 1000000, sd = 100))
expect_equal(result$gp_on, "R0")
})

test_that("rt_opts warns when pop is passed as numeric", {
expect_warning(
rt_opts(pop = 1000),
"Specifying `pop` as a numeric value is deprecated"
)
})

test_that("rt_opts sets use_breakpoints to TRUE when rw > 0", {
result <- rt_opts(rw = 3, use_breakpoints = FALSE)
expect_true(result$use_breakpoints)
Expand Down Expand Up @@ -59,8 +66,7 @@ test_that("rt_opts returns object of correct class", {
})

test_that("rt_opts handles edge cases correctly", {
result <- rt_opts(rw = 0.1, pop = -1)
result <- rt_opts(rw = 0.1)
expect_equal(result$rw, 0.1)
expect_equal(result$pop, -1)
expect_true(result$use_breakpoints)
})
2 changes: 1 addition & 1 deletion vignettes/estimate_infections_options.Rmd.orig
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ dep <- estimate_infections(reported_cases,
delays = delay_opts(delay),
rt = rt_opts(
prior = rt_prior,
pop = 1000000, future = "latest"
pop = Normal(mean = 1000000, sd = 1000), future = "latest"
)
)
# summarise results
Expand Down

0 comments on commit f93f380

Please sign in to comment.