Skip to content

Commit

Permalink
tests for cmdstanr backend
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk committed Feb 2, 2024
1 parent 232b8ea commit 2414824
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/testthat/test-epinow.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,32 @@ test_that("epinow produces expected output when run with default settings", {
expect_equal(names(out$plots), c("infections", "reports", "R", "growth_rate", "summary"))
})

test_that("epinow produces expected output when run with the
cmdstanr backend", {
output <- capture.output(suppressMessages(suppressWarnings(
out <- epinow(
reported_cases = reported_cases,
generation_time = generation_time_opts(example_generation_time),
delays = delay_opts(c(example_incubation_period, reporting_delay)),
stan = stan_opts(
samples = 25, warmup = 25,
cores = 1, chains = 2,
control = list(adapt_delta = 0.8),
backend = "cmdstanr"
),
logs = NULL, verbose = FALSE
)
)))

expect_equal(names(out), expected_out)
df_non_zero(out$estimates$samples)
df_non_zero(out$estimates$summarised)
df_non_zero(out$estimated_reported_cases$samples)
df_non_zero(out$estimated_reported_cases$summarised)
df_non_zero(out$summary)
expect_equal(names(out$plots), c("infections", "reports", "R", "growth_rate", "summary"))
})

test_that("epinow runs without error when saving to disk", {
expect_null(suppressWarnings(epinow(
reported_cases = reported_cases,
Expand Down
30 changes: 30 additions & 0 deletions tests/testthat/test-estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,21 @@ inc <- estimate_secondary(cases[1:60],
verbose = FALSE
)

output <- capture.output(suppressMessages(suppressWarnings(
inc_cmdstanr <- estimate_secondary(cases[1:60],
obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE),
verbose = FALSE, stan = stan_opts(backend = "cmdstanr")
)
)))

# extract posterior variables of interest
params <- c(
"meanlog" = "delay_mean[1]", "sdlog" = "delay_sd[1]",
"scaling" = "frac_obs[1]"
)

inc_posterior <- inc$posterior[variable %in% params]
inc_posterior_cmdstanr <- inc_cmdstanr$posterior[variable %in% params]

#### Prevalence data example ####

Expand Down Expand Up @@ -92,6 +100,18 @@ test_that("estimate_secondary can recover simulated parameters", {
)
})

test_that("estimate_secondary can recover simulated parameters with the
cmdstanr backend", {
expect_equal(
inc_posterior_cmdstanr[, mean], c(1.8, 0.5, 0.4),
tolerance = 0.1
)
expect_equal(
inc_posterior_cmdstanr[, median], c(1.8, 0.5, 0.4),
tolerance = 0.1
)
})

test_that("forecast_secondary can return values from simulated data and plot
them", {
inc_preds <- forecast_secondary(inc, cases[seq(61, .N)][, value := primary])
Expand All @@ -100,6 +120,16 @@ test_that("forecast_secondary can return values from simulated data and plot
expect_error(plot(inc_preds, new_obs = cases, from = "2020-05-01"), NA)
})

test_that("forecast_secondary can return values from simulated data when using
the cmdstanr backend", {
capture.output(suppressMessages(suppressWarnings(
inc_preds <- forecast_secondary(
inc_cmdstanr, cases[seq(61, .N)][, value := primary], backend = "cmdstanr"
)
)))
expect_equal(names(inc_preds), c("samples", "forecast", "predictions"))
})

test_that("estimate_secondary works with weigh_delay_priors = TRUE", {
delays <- dist_spec(
mean = 2.5, mean_sd = 0.5, sd = 0.47, sd_sd = 0.25, max = 30
Expand Down
16 changes: 16 additions & 0 deletions tests/testthat/test-estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@ test_that("estimate_truncation can return values from simulated data and plot
expect_error(plot(est), NA)
})

test_that("estimate_truncation can return values from simulated data with the
cmdstanr backend", {
# fit model to example data
output <- capture.output(suppressMessages(suppressWarnings(
est <- estimate_truncation(example_data,
verbose = FALSE, chains = 2, iter = 1000, warmup = 250,
stan = stan_opts(backend = "cmdstanr")
))))
expect_equal(
names(est),
c("dist", "obs", "last_obs", "cmf", "data", "fit")
)
expect_s3_class(est$dist, "dist_spec")
expect_error(plot(est), NA)
})

test_that("deprecated arguments are recognised", {
options(warn = 2)
expect_error(estimate_truncation(example_data,
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/test-simulate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ test_that("simulate_infections works to simulate a passed in estimate_infections
expect_equal(names(sims), c("samples", "summarised", "observations"))
})

test_that("simulate_infections works to simulate a passed in estimate_infections
object when using the cmdstanr backend", {
output <- capture.output(suppressMessages(suppressWarnings(
sims <- simulate_infections(out, backend = "cmdstanr")
)))
expect_equal(names(sims), c("samples", "summarised", "observations"))
})

test_that("simulate_infections works to simulate a passed in estimate_infections object with an adjusted Rt", {
R <- c(rep(NA_real_, 40), rep(0.5, 17))
sims <- simulate_infections(out, R)
Expand Down

0 comments on commit 2414824

Please sign in to comment.