Skip to content

Commit

Permalink
Merge branch 'main' into vectorise-spectral-density
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs authored Aug 27, 2024
2 parents abe77f6 + 5dac0b7 commit c8a8ca4
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 8 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Documentation

- The README has been updated to link to the free course on nowcasting and forecasting. The availability of variational inference, Laplace approximation, and Pathfinder through `cmdstanr` has also be surfaced. By @jamesmbaazam in #753 and reviewed by @seabbs.
- Some implicit argument defaults have been made explicit in the function definition. By @Bisaloo in #729.
- The installation guide in the README has been updated to provide instructions for configuring the C toolchain of Windows, MacOS, and Linux. By @jamesmbaazam in #707 and reviewed by @sbfnk.

Expand All @@ -16,6 +17,8 @@
- The default stan control options have been updated from `list(adapt_delta = 0.95, max_treedepth = 15)` to `list(adapt_delta = 0.9, max_treedepth = 12)` due to improved performance and to reduce the runtime of the default parameterisations. By @seabbs in #742 and reviewed by @jamesmbaazam.
- Initialisation has been simplified by sampling directly from the priors, where possible, rather than from a constrained space. By @seabbs in #742 and reviewed by @jamesmbaazam.
- Unnecessary normalisation of delay priors has been removed. By @seabbs in #742 and reviewed by @jamesmbaazam.
- Switch to broadcasting the day of the week effect. By @seabbs in #746 and reviewed by @jamesmbaazam.
- A warning is now thrown if nonparametric PMFs passed to delay options have consecutive tail values that are below a certain low threshold as these lead to loss in speed with little gain in accuracy. By @jamesmbaazam in #752 and reviewed by @seabbs.

## Bug fixes

Expand Down
24 changes: 24 additions & 0 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,27 @@ check_stan_delay <- function(dist) {
)
}
}

#' Check that PMF tail is not sparse
#'
#' @description Checks if the tail of a PMF vector has more than `span`
#' consecutive values smaller than `tol` and throws a warning if so.
#' @param pmf A probability mass function vector
#' @param span The number of consecutive indices in the tail to check
#' @param tol The value which to consider the tail as sparse
#'
#' @return Called for its side effects.
#' @keywords internal
check_sparse_pmf_tail <- function(pmf, span = 5, tol = 1e-6) {
if (all(pmf[(length(pmf) - span + 1):length(pmf)] < tol)) {
warning(
sprintf(
"The PMF tail has %s consecutive values smaller than %s.",
span, tol
),
" This will drastically increase run time with very small increases ",
"in accuracy. Consider increasing the tail values of the PMF.",
call. = FALSE
)
}
}
1 change: 1 addition & 0 deletions R/dist_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,7 @@ Fixed <- function(value, ...) {
#' NonParametric(c(0.1, 0.3, 0.2, 0.4))
#' NonParametric(c(0.1, 0.3, 0.2, 0.1, 0.1))
NonParametric <- function(pmf, ...) {
check_sparse_pmf_tail(pmf)
params <- list(pmf = pmf / sum(pmf))
return(new_dist_spec(params, "nonparametric"))
}
Expand Down
4 changes: 4 additions & 0 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ The default model in `estimate_infections()` uses a non-stationary Gaussian proc
* A deconvolution/back-calculation method for inferring infections, followed with calculating the time-varying reproduction number.
* Adjustment for the remaining susceptible population beyond the forecast horizon.

By default, all these models are fit with [MCMC sampling](https://mc-stan.org/docs/reference-manual/mcmc.html) using the [`rstan`](https://mc-stan.org/users/interfaces/rstan) R package as the backend. Users can, however, switch to use approximate algorithms like [variational inference](https://en.wikipedia.org/wiki/Variational_Bayesian_methods), the [pathfinder](https://mc-stan.org/docs/reference-manual/pathfinder.html) algorithm, or [Laplace approximation](https://mc-stan.org/docs/reference-manual/laplace.html) especially for quick prototyping. The latter two methods are provided through the [`cmdstanr`](https://mc-stan.org/cmdstanr/) R package, so users will have to install that separately.

The documentation for `estimate_infections` provides examples of the implementation of the different options available.

`{EpiNow2}` is designed to be used via a single function call to two functions:
Expand Down Expand Up @@ -114,6 +116,8 @@ is your quickest entry point to the package. It provides a quick run through of
the two main functions in the package and how to set up them up. It also
discusses how to summarise and visualise the results after running the models.

More broadly, users can also learn the details of estimating delay distributions, nowcasting, and forecasting in a structured way through the free and open short-course, ["Nowcasting and forecasting infectious disease dynamics"](https://nfidd.github.io/nfidd/), developed by some authors of this package.

</details>

<details> <summary> Package website </summary>
Expand Down
9 changes: 1 addition & 8 deletions inst/stan/functions/observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,9 @@
* @return A vector of reports adjusted for day of the week effects.
*/
vector day_of_week_effect(vector reports, array[] int day_of_week, vector effect) {
int t = num_elements(reports);
int wl = num_elements(effect);
// scale day of week effect
vector[wl] scaled_effect = wl * effect;
vector[t] scaled_reports;
for (s in 1:t) {
// add reporting effects (adjust for simplex scale)
scaled_reports[s] = reports[s] * scaled_effect[day_of_week[s]];
}
return(scaled_reports);
return reports .* scaled_effect[day_of_week];
}

/**
Expand Down
23 changes: 23 additions & 0 deletions man/check_sparse_pmf_tail.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions tests/testthat/test-checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,8 @@ test_that("check_reports_valid errors for bad 'secondary' specifications", {
# Run tests
test_col_specs(secondary_col_dt, model = "estimate_secondary")
})

test_that("check_sparse_pmf_tail throws a warning as expected", {
pmf <- c(0.4, 0.30, 0.20, 0.05, 0.049995, 4.5e-06, rep(1e-7, 5))
expect_warning(check_sparse_pmf_tail(pmf), "consecutive values smaller than")
})
37 changes: 37 additions & 0 deletions tests/testthat/test-stan-observation_model.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
skip_on_cran()
skip_on_os("windows")

test_that("day_of_week_effect applies day of week effect correctly", {
reports <- c(100, 200, 300, 400, 500, 600, 700, 800, 900, 1000)
day_of_week <- c(1, 2, 3, 1, 2, 3, 1, 2, 3, 1)
effect <- c(1.0, 1.1, 1.2)

expected <- reports * effect[day_of_week] * 3
result <- day_of_week_effect(reports, day_of_week, effect)

expect_equal(result, expected)
})

test_that("scale_obs scales reports by fraction observed", {
reports <- c(100, 200, 300, 400, 500, 600, 700, 800, 900, 1000)
frac_obs <- 0.5

expected <- c(50, 100, 150, 200, 250, 300, 350, 400, 450, 500)
result <- scale_obs(reports, frac_obs)

expect_equal(result, expected)
})

test_that("truncate_obs truncates reports correctly", {
reports <- c(100, 200, 300, 400, 500, 600, 700, 800, 900, 1000)
trunc_rev_cmf <- c(1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1)

expected_truncate <- c(100, 180, 240, 280, 300, 300, 280, 240, 180, 100)
result_truncate <- truncate_obs(reports, trunc_rev_cmf, reconstruct = 0)

expect_equal(result_truncate, expected_truncate)

result_reconstruct <- truncate_obs(expected_truncate, trunc_rev_cmf, reconstruct = 1)

expect_equal(result_reconstruct, reports)
})

0 comments on commit c8a8ca4

Please sign in to comment.