Skip to content

Commit

Permalink
distribution interface to dist_spec (#504)
Browse files Browse the repository at this point in the history
* add distribution functions

* deprecate "empty" distribution

* make sd S3

* only generate samples if any params aren't natural

* update stan model with new dist interface

* update lognormal parameters

* return mean function to previous functionality

* update data

* deprecate dist_def functions

* use natural parametrisations in dist_def functions

* deprecate dist_spec

* extract_single_dist function

* update fix_dist to work with compsosite dists

* extract squash

* update parameters to extract

* specify lower bounds in function

* pass lower bounds to stan model

* update sample/report functions

* max squash adjust report

* update dist functions to new syntax

* re-create data

* update get_dist to new syntax

* fully deprecate get fnuctions

* create delay inits separately

* max squash again

* return correct dist in estimate_truncation

* few more examples/docs

* fix tests

* add documentation to dist interface

* add input checks

* sd function to work with composite dists

* warn when not using natural parameters

* ensure bounds are respected in stan

* add empty distribution for legacy reasons

* add checks to dist_skel

* use lapply for parameters

* don't calculate sd if length 1

* use uncertain reporting in example

* don't add one to sd

* return correct parameters

* dist_skel: calculate rate everywhere

* update dist_skel examples

* add missing man file

* don't run internal examples

* demote warning to message

* update syntax everywhere

* add news item

* turn sd into an internal function

* fix distribution documentation

* remove obselete default

* spell checking

* use correct sd function

* linting

* remove obsolete tests

* loop over all parameters

* update touchstone arguments

* linting

* fix regex search/replace gone wrong

* remove obsolete space

* update strategy for estimating uncertainty

* update uncertain parameter transformations

* add missing sd to parameter sampling

* update / recompile vignettes

* update var names

* rename argument in docs

* update man pages

* update test result

* add reviewer

Co-authored-by: Sam Abbott <[email protected]>

* base scaling on variance, not sd

* re-render vignettes

* full text capitalisation of distributions

* separate dist_spec from stan model

* adjust tests/code for new dist_spec set up

* re-create examples

* re-doc

* update tests

* new dist_spec in estimate_truncation example

* update get_seeding_time with updated dist_spec

* estimate_truncation and seeding time tests

* update truncation dist in estimate_truncation

* remove more uses of old dist_spec

* SD explicitly to zero for fixed

* give names

* fix typo

* fix indent

* fix another typo

* squash bugs highlighted by tests

* remove missing variable

* linting

* add missing docs

* import transpose

* ensure sd is positive

* fix estimate_truncation example

* make tolerance user-settable

* use purrr::map instead of lapply

* fix stan dist test

* fix plotting

* Apply suggestions from code review

Co-authored-by: Sam Abbott <[email protected]>

* rate and scale examples for Gamma

Co-authored-by: Sam Abbott <[email protected]>

* capitalise gamma and lognormal

Co-authored-by: Sam Abbott <[email protected]>

* change to single hash

* use bar in normal_cdf

* remove estraneous backticks

* remove space before left parenthesis

* split up dist.R

* move deprecated `dist_spec` function

* add examples

* initial design sketch

* make parameter conversion more flexible

* add test for alternative gama params

* update syntax in simulate_infections

* add missing tag

* update man pages

* update estimate_secondary tests

* update simulate_infections for new interface

* udpate snapshots

* get_dist deprecation test with natural params

* update phi syntax

* hide internal example

* update deprecations

* use toString

* pmf -> NonParametric

* add american spelling

* fix gamma deprecation

* add new functions to pkgdown

* update vignette

* recompile vignettes

---------

Co-authored-by: Sam Abbott <[email protected]>
  • Loading branch information
sbfnk and seabbs authored Mar 12, 2024
1 parent 3a3d8d8 commit 19b5707
Show file tree
Hide file tree
Showing 122 changed files with 3,774 additions and 2,590 deletions.
14 changes: 14 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Generated by roxygen2: do not edit by hand

S3method("+",dist_spec)
S3method(c,dist_spec)
S3method(max,dist_spec)
S3method(mean,dist_spec)
S3method(plot,dist_spec)
S3method(plot,epinow)
Expand All @@ -10,16 +12,23 @@ S3method(plot,estimate_truncation)
S3method(print,dist_spec)
S3method(summary,epinow)
S3method(summary,estimate_infections)
export(Fixed)
export(Gamma)
export(LogNormal)
export(NonParametric)
export(Normal)
export(R_to_growth)
export(add_day_of_week)
export(adjust_infection_to_report)
export(apply_tolerance)
export(backcalc_opts)
export(bootstrapped_dist_fit)
export(calc_CrI)
export(calc_CrIs)
export(calc_summary_measures)
export(calc_summary_stats)
export(clean_nowcasts)
export(collapse)
export(construct_output)
export(convert_to_logmean)
export(convert_to_logsd)
Expand All @@ -35,6 +44,8 @@ export(create_shifted_cases)
export(create_stan_args)
export(create_stan_data)
export(delay_opts)
export(discretise)
export(discretize)
export(dist_fit)
export(dist_skel)
export(dist_spec)
Expand Down Expand Up @@ -195,16 +206,19 @@ importFrom(posterior,mcse_mean)
importFrom(progressr,progressor)
importFrom(progressr,with_progress)
importFrom(purrr,compact)
importFrom(purrr,flatten)
importFrom(purrr,keep)
importFrom(purrr,list_transpose)
importFrom(purrr,map)
importFrom(purrr,map2_dbl)
importFrom(purrr,map_chr)
importFrom(purrr,map_dbl)
importFrom(purrr,map_dfc)
importFrom(purrr,pmap_dbl)
importFrom(purrr,quietly)
importFrom(purrr,reduce)
importFrom(purrr,safely)
importFrom(purrr,transpose)
importFrom(purrr,walk)
importFrom(rlang,abort)
importFrom(rlang,arg_match)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
* `simulate_infections` has been renamed to `forecast_infections` in line with `simulate_secondary` and `forecast_secondary`. The terminology is: a forecast is done from a fit to existing data, a simulation from first principles. By @sbfnk in #544 and reviewed by @seabbs.
* A new `simulate_infections` function has been added that can be used to simulate from the model from given initial conditions and parameters. By @sbfnk in #557 and reviewed by @jamesmbaazam.
* The function `init_cumulative_fit()` has been deprecated. By @jamesmbaazam in #541 and reviewed by @sbfnk.
* The interface to generating delay distributions has been completely overhauled. Instead of calling `dist_spec()` users now specify distributions using functions that represent the available distributions, i.e. `LogNormal()`, `Gamma()` and `Fixed()`. Uncertainty is specified using calls of the same nature, to `Normal()`. More information on the underlying design can be found in `inst/dev/design_dist.md` By @sbfnk in #504 and reviewed by @seabbs.

## Documentation

Expand Down
232 changes: 227 additions & 5 deletions R/adjust.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ adjust_infection_to_report <- function(infections, delay_defs,
# Reset DT Defaults on Exit
set_dt_single_thread()

## deprecated
sample_single_dist <- function(input, delay_def) {
## Define sample delay fn
sample_delay_fn <- function(n, ...) {
Expand All @@ -111,14 +112,50 @@ adjust_infection_to_report <- function(infections, delay_defs,
return(out)
}

report <- sample_single_dist(infections, delay_defs[[1]])

if (length(delay_defs) > 1) {
for (def in 2:length(delay_defs)) {
report <- sample_single_dist(report, delay_defs[[def]])
sample_dist_spec <- function(input, delay_def) {
## Define sample delay fn
sample_delay_fn <- function(n, dist, cum, ...) {
fixed_dist <- discretise(fix_dist(delay_def, strategy = "sample"))
if (dist) {
fixed_dist[[1]]$pmf[n + 1]
} else {
sample(seq_along(fixed_dist[[1]]$pmf) - 1, size = n, replace = TRUE)
}
}

## Infection to onset
out <- EpiNow2::sample_approx_dist(
cases = input,
dist_fn = sample_delay_fn,
max_value = max(delay_def),
direction = "forwards",
type = type,
truncate_future = FALSE
)

return(out)
}

if (is(delay_defs, "dist_spec")) {
report <- sample_dist_spec(infections, extract_single_dist(delay_defs, 1))
if (length(delay_defs) > 1) {
for (def in seq(2, length(delay_defs))) {
report <- sample_dist_spec(report, extract_single_dist(delay_defs, def))
}
}
} else {
deprecate_warn(
"1.5.0",
"adjust_infection_to_report(delay_defs = 'should be a dist_spec')",
details = "Specifying this as a list of data tables is deprecated."
)
report <- sample_single_dist(infections, delay_defs[[1]])
if (length(delay_defs) > 1) {
for (def in 2:length(delay_defs)) {
report <- sample_single_dist(report, delay_defs[[def]])
}
}
}
## Add a weekly reporting effect if present
if (!missing(reporting_effect)) {
reporting_effect <- data.table::data.table(
Expand Down Expand Up @@ -146,3 +183,188 @@ adjust_infection_to_report <- function(infections, delay_defs,
}
return(report)
}

#' Approximate Sampling a Distribution using Counts
#'
#' @description `r lifecycle::badge("soft-deprecated")`
#' Convolves cases by a PMF function. This function will soon be removed or
#' replaced with a more robust stan implementation.
#'
#' @param cases A `<data.frame>` of cases (in date order) with the following
#' variables: `date` and `cases`.
#'
#' @param max_value Numeric, maximum value to allow. Defaults to 120 days
#'
#' @param direction Character string, defato "backwards". Direction in which to
#' map cases. Supports either "backwards" or "forwards".
#'
#' @param dist_fn Function that takes two arguments with the first being
#' numeric and the second being logical (and defined as `dist`). Should return
#' the probability density or a sample from the defined distribution. See
#' the examples for more.
#'
#' @param earliest_allowed_mapped A character string representing a date
#' ("2020-01-01"). Indicates the earliest allowed mapped value.
#'
#' @param type Character string indicating the method to use to transform
#' counts. Supports either "sample" which approximates sampling or "median"
#' would shift by the median of the distribution.
#'
#' @param truncate_future Logical, should cases be truncated if they occur
#' after the first date reported in the data. Defaults to `TRUE`.
#'
#' @return A `<data.table>` of cases by date of onset
#' @export
#' @importFrom purrr map_dfc
#' @importFrom data.table data.table setorder
#' @importFrom lubridate days
#' @examples
#' \donttest{
#' cases <- example_confirmed
#' cases <- cases[, cases := as.integer(confirm)]
#' print(cases)
#'
#' # total cases
#' sum(cases$cases)
#'
#' delay_fn <- function(n, dist, cum) {
#' if (dist) {
#' pgamma(n + 0.9999, 2, 1) - pgamma(n - 1e-5, 2, 1)
#' } else {
#' as.integer(rgamma(n, 2, 1))
#' }
#' }
#'
#' onsets <- sample_approx_dist(
#' cases = cases,
#' dist_fn = delay_fn
#' )
#'
#' # estimated onset distribution
#' print(onsets)
#'
#' # check that sum is equal to reported cases
#' total_onsets <- median(
#' purrr::map_dbl(
#' 1:10,
#' ~ sum(sample_approx_dist(
#' cases = cases,
#' dist_fn = delay_fn
#' )$cases)
#' )
#' )
#' total_onsets
#'
#'
#' # map from onset cases to reported
#' reports <- sample_approx_dist(
#' cases = cases,
#' dist_fn = delay_fn,
#' direction = "forwards"
#' )
#'
#'
#' # map from onset cases to reported using a mean shift
#' reports <- sample_approx_dist(
#' cases = cases,
#' dist_fn = delay_fn,
#' direction = "forwards",
#' type = "median"
#' )
#' }
sample_approx_dist <- function(cases = NULL,
dist_fn = NULL,
max_value = 120,
earliest_allowed_mapped = NULL,
direction = "backwards",
type = "sample",
truncate_future = TRUE) {
if (type == "sample") {
if (direction == "backwards") {
direction_fn <- rev
} else if (direction == "forwards") {
direction_fn <- function(x) {
x
}
}
# reverse cases so starts with current first
reversed_cases <- direction_fn(cases$cases)
reversed_cases[is.na(reversed_cases)] <- 0
# draw from the density fn of the dist
draw <- dist_fn(0:max_value, dist = TRUE, cum = FALSE)

# approximate cases
mapped_cases <- do.call(cbind, purrr::map(
seq_along(reversed_cases),
~ c(
rep(0, . - 1),
stats::rbinom(
length(draw),
rep(reversed_cases[.], length(draw)),
draw
),
rep(0, length(reversed_cases) - .)
)
))


# set dates order based on direction mapping
if (direction == "backwards") {
dates <- seq(min(cases$date) - lubridate::days(length(draw) - 1),
max(cases$date),
by = "days"
)
} else if (direction == "forwards") {
dates <- seq(min(cases$date),
max(cases$date) + lubridate::days(length(draw) - 1),
by = "days"
)
}

# summarises movements and sample for placement of non-integer cases
case_sum <- direction_fn(rowSums(mapped_cases))
floor_case_sum <- floor(case_sum)
sample_cases <- floor_case_sum +
as.numeric((runif(seq_along(case_sum)) < (case_sum - floor_case_sum)))

# summarise imputed onsets and build output data.table
mapped_cases <- data.table::data.table(
date = dates,
cases = sample_cases
)

# filter out all zero cases until first recorded case
mapped_cases <- data.table::setorder(mapped_cases, date)
mapped_cases <- mapped_cases[
,
cum_cases := cumsum(cases)
][cum_cases != 0][, cum_cases := NULL]
} else if (type == "median") {
shift <- as.integer(
median(as.integer(dist_fn(1000, dist = FALSE)), na.rm = TRUE)
)

if (direction == "backwards") {
mapped_cases <- data.table::copy(cases)[
,
date := date - lubridate::days(shift)
]
} else if (direction == "forwards") {
mapped_cases <- data.table::copy(cases)[
,
date := date + lubridate::days(shift)
]
}
}

if (!is.null(earliest_allowed_mapped)) {
mapped_cases <- mapped_cases[date >= as.Date(earliest_allowed_mapped)]
}

# filter out future cases
if (direction == "forwards" && truncate_future) {
max_date <- max(cases$date)
mapped_cases <- mapped_cases[date <= max_date]
}
return(mapped_cases)
}
50 changes: 50 additions & 0 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,53 @@ check_reports_valid <- function(reports, model) {
assert_numeric(reports$confirm, lower = 0)
}
}

#' Validate probability distribution for passing to stan
#'
#' @description
#' `check_stan_delay()` checks that the supplied data is a `<dist_spec>`,
#' that it is a supported distribution, and that is has a finite maximum.
#'
#' @param dist A `dist_spec` object.`
#' @importFrom checkmate assert_class
#' @importFrom rlang arg_match
#' @return Called for its side effects.
#' @keywords internal
check_stan_delay <- function(dist) {
# Check that `dist` is a `dist_spec`
assert_class(dist, "dist_spec")
# Check that `dist` is lognormal or gamma or nonparametric
distributions <- vapply(dist, function(x) x$distribution, character(1))
if (
!all(distributions %in% c("lognormal", "gamma", "fixed", "nonparametric"))
) {
stop(
"Distributions passed to the model need to be lognormal, gamma, fixed ",
"or nonparametric."
)
}
# Check that `dist` has parameters that are either numeric or normal
# distributions with numeric parameters and infinite maximum
numeric_parameters <- vapply(dist$parameters, is.numeric, logical(1))
normal_parameters <- vapply(
dist$parameters,
function(x) {
is(x, "dist_spec") &&
x$distribution == "normal" &&
all(vapply(x$parameters, is.numeric, logical(1))) &&
is.infinite(x$max)
},
logical(1)
)
if (!all(numeric_parameters | normal_parameters)) {
stop(
"Delay distributions passed to the model need to have parameters that ",
"are either numeric or normally distributed with numeric parameters ",
"and infinite maximum."
)
}
# Check that `dist` has a finite maximum
if (any(is.infinite(max(dist)))) {
stop("All distribution passed to the model need to have a finite maximum")
}
}
Loading

0 comments on commit 19b5707

Please sign in to comment.