From 2414824c47da4a16957c9e78216f76721fccc97d Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 1 Feb 2024 22:14:37 +0000 Subject: [PATCH] tests for cmdstanr backend --- tests/testthat/test-epinow.R | 26 ++++++++++++++++++++ tests/testthat/test-estimate_secondary.R | 30 +++++++++++++++++++++++ tests/testthat/test-estimate_truncation.R | 16 ++++++++++++ tests/testthat/test-simulate_infections.R | 8 ++++++ 4 files changed, 80 insertions(+) diff --git a/tests/testthat/test-epinow.R b/tests/testthat/test-epinow.R index 8dc6bd340..fe481b909 100644 --- a/tests/testthat/test-epinow.R +++ b/tests/testthat/test-epinow.R @@ -38,6 +38,32 @@ test_that("epinow produces expected output when run with default settings", { expect_equal(names(out$plots), c("infections", "reports", "R", "growth_rate", "summary")) }) +test_that("epinow produces expected output when run with the + cmdstanr backend", { + output <- capture.output(suppressMessages(suppressWarnings( + out <- epinow( + reported_cases = reported_cases, + generation_time = generation_time_opts(example_generation_time), + delays = delay_opts(c(example_incubation_period, reporting_delay)), + stan = stan_opts( + samples = 25, warmup = 25, + cores = 1, chains = 2, + control = list(adapt_delta = 0.8), + backend = "cmdstanr" + ), + logs = NULL, verbose = FALSE + ) + ))) + + expect_equal(names(out), expected_out) + df_non_zero(out$estimates$samples) + df_non_zero(out$estimates$summarised) + df_non_zero(out$estimated_reported_cases$samples) + df_non_zero(out$estimated_reported_cases$summarised) + df_non_zero(out$summary) + expect_equal(names(out$plots), c("infections", "reports", "R", "growth_rate", "summary")) +}) + test_that("epinow runs without error when saving to disk", { expect_null(suppressWarnings(epinow( reported_cases = reported_cases, diff --git a/tests/testthat/test-estimate_secondary.R b/tests/testthat/test-estimate_secondary.R index a887c2381..ad9a586b8 100644 --- a/tests/testthat/test-estimate_secondary.R +++ b/tests/testthat/test-estimate_secondary.R @@ -25,6 +25,13 @@ inc <- estimate_secondary(cases[1:60], verbose = FALSE ) +output <- capture.output(suppressMessages(suppressWarnings( + inc_cmdstanr <- estimate_secondary(cases[1:60], + obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + verbose = FALSE, stan = stan_opts(backend = "cmdstanr") + ) +))) + # extract posterior variables of interest params <- c( "meanlog" = "delay_mean[1]", "sdlog" = "delay_sd[1]", @@ -32,6 +39,7 @@ params <- c( ) inc_posterior <- inc$posterior[variable %in% params] +inc_posterior_cmdstanr <- inc_cmdstanr$posterior[variable %in% params] #### Prevalence data example #### @@ -92,6 +100,18 @@ test_that("estimate_secondary can recover simulated parameters", { ) }) +test_that("estimate_secondary can recover simulated parameters with the + cmdstanr backend", { + expect_equal( + inc_posterior_cmdstanr[, mean], c(1.8, 0.5, 0.4), + tolerance = 0.1 + ) + expect_equal( + inc_posterior_cmdstanr[, median], c(1.8, 0.5, 0.4), + tolerance = 0.1 + ) +}) + test_that("forecast_secondary can return values from simulated data and plot them", { inc_preds <- forecast_secondary(inc, cases[seq(61, .N)][, value := primary]) @@ -100,6 +120,16 @@ test_that("forecast_secondary can return values from simulated data and plot expect_error(plot(inc_preds, new_obs = cases, from = "2020-05-01"), NA) }) +test_that("forecast_secondary can return values from simulated data when using + the cmdstanr backend", { + capture.output(suppressMessages(suppressWarnings( + inc_preds <- forecast_secondary( + inc_cmdstanr, cases[seq(61, .N)][, value := primary], backend = "cmdstanr" + ) + ))) + expect_equal(names(inc_preds), c("samples", "forecast", "predictions")) +}) + test_that("estimate_secondary works with weigh_delay_priors = TRUE", { delays <- dist_spec( mean = 2.5, mean_sd = 0.5, sd = 0.47, sd_sd = 0.25, max = 30 diff --git a/tests/testthat/test-estimate_truncation.R b/tests/testthat/test-estimate_truncation.R index dad8f4b12..ae72d82f4 100644 --- a/tests/testthat/test-estimate_truncation.R +++ b/tests/testthat/test-estimate_truncation.R @@ -54,6 +54,22 @@ test_that("estimate_truncation can return values from simulated data and plot expect_error(plot(est), NA) }) +test_that("estimate_truncation can return values from simulated data with the + cmdstanr backend", { + # fit model to example data + output <- capture.output(suppressMessages(suppressWarnings( + est <- estimate_truncation(example_data, + verbose = FALSE, chains = 2, iter = 1000, warmup = 250, + stan = stan_opts(backend = "cmdstanr") + )))) + expect_equal( + names(est), + c("dist", "obs", "last_obs", "cmf", "data", "fit") + ) + expect_s3_class(est$dist, "dist_spec") + expect_error(plot(est), NA) +}) + test_that("deprecated arguments are recognised", { options(warn = 2) expect_error(estimate_truncation(example_data, diff --git a/tests/testthat/test-simulate_infections.R b/tests/testthat/test-simulate_infections.R index dbe55e677..951ddce6b 100644 --- a/tests/testthat/test-simulate_infections.R +++ b/tests/testthat/test-simulate_infections.R @@ -20,6 +20,14 @@ test_that("simulate_infections works to simulate a passed in estimate_infections expect_equal(names(sims), c("samples", "summarised", "observations")) }) +test_that("simulate_infections works to simulate a passed in estimate_infections + object when using the cmdstanr backend", { + output <- capture.output(suppressMessages(suppressWarnings( + sims <- simulate_infections(out, backend = "cmdstanr") + ))) + expect_equal(names(sims), c("samples", "summarised", "observations")) +}) + test_that("simulate_infections works to simulate a passed in estimate_infections object with an adjusted Rt", { R <- c(rep(NA_real_, 40), rep(0.5, 17)) sims <- simulate_infections(out, R)