From 5dac0b71133e98c42a1ec527abf3aa2c7962c142 Mon Sep 17 00:00:00 2001 From: Sam Abbott Date: Tue, 27 Aug 2024 14:15:12 +0100 Subject: [PATCH] Use broadcasting vs loop in day of week (#746) * use broadcasting vs loop * add tests for some obs models stan fns * Update NEWS.md --- NEWS.md | 1 + inst/stan/functions/observation_model.stan | 9 +---- tests/testthat/test-stan-observation_model.R | 37 ++++++++++++++++++++ 3 files changed, 39 insertions(+), 8 deletions(-) create mode 100644 tests/testthat/test-stan-observation_model.R diff --git a/NEWS.md b/NEWS.md index 175c03cd5..8aa164acf 100644 --- a/NEWS.md +++ b/NEWS.md @@ -12,6 +12,7 @@ - The interface for defining delay distributions has been generalised to also cater for continuous distributions - When defining probability distributions these can now be truncated using the `tolerance` argument - Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in # and reviewed by @. +- Switch to broadcasting the day of the week effect. By @seabbs in #746 and reviewed by @jamesmbaazam. - A warning is now thrown if nonparametric PMFs passed to delay options have consecutive tail values that are below a certain low threshold as these lead to loss in speed with little gain in accuracy. By @jamesmbaazam in #752 and reviewed by @seabbs. ## Bug fixes diff --git a/inst/stan/functions/observation_model.stan b/inst/stan/functions/observation_model.stan index f55210281..aa10ccfda 100644 --- a/inst/stan/functions/observation_model.stan +++ b/inst/stan/functions/observation_model.stan @@ -10,16 +10,9 @@ * @return A vector of reports adjusted for day of the week effects. */ vector day_of_week_effect(vector reports, array[] int day_of_week, vector effect) { - int t = num_elements(reports); int wl = num_elements(effect); - // scale day of week effect vector[wl] scaled_effect = wl * effect; - vector[t] scaled_reports; - for (s in 1:t) { - // add reporting effects (adjust for simplex scale) - scaled_reports[s] = reports[s] * scaled_effect[day_of_week[s]]; - } - return(scaled_reports); + return reports .* scaled_effect[day_of_week]; } /** diff --git a/tests/testthat/test-stan-observation_model.R b/tests/testthat/test-stan-observation_model.R new file mode 100644 index 000000000..27a6c2207 --- /dev/null +++ b/tests/testthat/test-stan-observation_model.R @@ -0,0 +1,37 @@ +skip_on_cran() +skip_on_os("windows") + +test_that("day_of_week_effect applies day of week effect correctly", { + reports <- c(100, 200, 300, 400, 500, 600, 700, 800, 900, 1000) + day_of_week <- c(1, 2, 3, 1, 2, 3, 1, 2, 3, 1) + effect <- c(1.0, 1.1, 1.2) + + expected <- reports * effect[day_of_week] * 3 + result <- day_of_week_effect(reports, day_of_week, effect) + + expect_equal(result, expected) +}) + +test_that("scale_obs scales reports by fraction observed", { + reports <- c(100, 200, 300, 400, 500, 600, 700, 800, 900, 1000) + frac_obs <- 0.5 + + expected <- c(50, 100, 150, 200, 250, 300, 350, 400, 450, 500) + result <- scale_obs(reports, frac_obs) + + expect_equal(result, expected) +}) + +test_that("truncate_obs truncates reports correctly", { + reports <- c(100, 200, 300, 400, 500, 600, 700, 800, 900, 1000) + trunc_rev_cmf <- c(1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1) + + expected_truncate <- c(100, 180, 240, 280, 300, 300, 280, 240, 180, 100) + result_truncate <- truncate_obs(reports, trunc_rev_cmf, reconstruct = 0) + + expect_equal(result_truncate, expected_truncate) + + result_reconstruct <- truncate_obs(expected_truncate, trunc_rev_cmf, reconstruct = 1) + + expect_equal(result_reconstruct, reports) +})