Skip to content

Commit

Permalink
Allow specifying tolerance at to dist_spec upon definition (#724)
Browse files Browse the repository at this point in the history
* flexibly specify bounds on distributions

* make max/tolerance arguments explicit

* fix S3 documentation

* update new_dist_spec use

* update plot documentation

* create man files

* correctly name var

* remove obsolete return statements

* add global variable

* add news item

* simplify map syntax

* remove unused variables

* add distribution to globals

* fix x-axis

* add examples

* add tests for specifying tolerance

* add comment

* improve error message [ci skip]

Co-authored-by: James Azam <[email protected]>

* remove superseded comment

* clarify NA return values

* add informative error message

* remove superseded comment

* add [is_constrained()] function

* remove unneeded checks

* update pkgdown

* set default tolerance in _opts functions

* make deprecated functions internal

* update tolerance argument

* remove whitespace

* improve printout

* improve documentation

* clarify id doc

* use correct function name

---------

Co-authored-by: James Azam <[email protected]>
  • Loading branch information
sbfnk and jamesmbaazam authored Aug 1, 2024
1 parent 8476e3a commit ee4043d
Show file tree
Hide file tree
Showing 44 changed files with 912 additions and 583 deletions.
16 changes: 13 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,25 @@

S3method("+",dist_spec)
S3method(c,dist_spec)
S3method(discretise,dist_spec)
S3method(discretise,multi_dist_spec)
S3method(fix_dist,dist_spec)
S3method(fix_dist,multi_dist_spec)
S3method(is_constrained,dist_spec)
S3method(is_constrained,multi_dist_spec)
S3method(max,dist_spec)
S3method(max,multi_dist_spec)
S3method(mean,dist_spec)
S3method(mean,multi_dist_spec)
S3method(plot,dist_spec)
S3method(plot,epinow)
S3method(plot,estimate_infections)
S3method(plot,estimate_secondary)
S3method(plot,estimate_truncation)
S3method(print,dist_spec)
S3method(sd,default)
S3method(sd,dist_spec)
S3method(sd,multi_dist_spec)
S3method(summary,epinow)
S3method(summary,estimate_infections)
export(Fixed)
Expand All @@ -19,9 +30,9 @@ export(NonParametric)
export(Normal)
export(R_to_growth)
export(adjust_infection_to_report)
export(apply_tolerance)
export(backcalc_opts)
export(bootstrapped_dist_fit)
export(bound_dist)
export(calc_CrI)
export(calc_CrIs)
export(calc_summary_measures)
Expand All @@ -35,8 +46,6 @@ export(delay_opts)
export(discretise)
export(discretize)
export(dist_fit)
export(dist_skel)
export(dist_spec)
export(epinow)
export(epinow2_cmdstan_model)
export(estimate_delay)
Expand All @@ -63,6 +72,7 @@ export(get_regional_results)
export(gp_opts)
export(growth_to_R)
export(gt_opts)
export(is_constrained)
export(lognorm_dist_def)
export(make_conf)
export(map_prob_change)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
## Model changes

- `epinow()` now returns the "timing" output in a "time difference"" format that is easier to understand and work with. By @jamesmbaazam in #688 and reviewed by @sbfnk.
- The interface for defining delay distributions has been generalised to also cater for continuous distributions
- When defining probability distributions these can now be truncated using the `tolerance` argument

## Bug fixes

Expand Down
28 changes: 15 additions & 13 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ check_stan_delay <- function(dist) {
# Check that `dist` is a `dist_spec`
assert_class(dist, "dist_spec")
# Check that `dist` is lognormal or gamma or nonparametric
distributions <- vapply(dist, function(x) x$distribution, character(1))
distributions <- vapply(
seq_len(ndist(dist)), get_distribution, x = dist, FUN.VALUE = character(1)
)
if (
!all(distributions %in% c("lognormal", "gamma", "fixed", "nonparametric"))
) {
Expand All @@ -78,24 +80,24 @@ check_stan_delay <- function(dist) {
}
# Check that `dist` has parameters that are either numeric or normal
# distributions with numeric parameters and infinite maximum
numeric_parameters <- vapply(dist$parameters, is.numeric, logical(1))
normal_parameters <- vapply(
dist$parameters,
function(x) {
is(x, "dist_spec") &&
x$distribution == "normal" &&
all(vapply(x$parameters, is.numeric, logical(1))) &&
is.infinite(x$max)
},
logical(1)
)
if (!all(numeric_parameters | normal_parameters)) {
numeric_or_normal <- unlist(lapply(seq_len(ndist(dist)), function(id) {
params <- get_parameters(dist, id)
vapply(params, function(x) {
is.numeric(x) ||
(is(x, "dist_spec") && get_distribution(x) == "normal" &&
is.infinite(max(x)))
}, logical(1))
}))
if (!all(numeric_or_normal)) {
stop(
"Delay distributions passed to the model need to have parameters that ",
"are either numeric or normally distributed with numeric parameters ",
"and infinite maximum."
)
}
if (is.null(attr(dist, "tolerance"))) {
attr(dist, "tolerance") <- 0
}
assert_numeric(attr(dist, "tolerance"), lower = 0, upper = 1)
# Check that `dist` has a finite maximum
if (any(is.infinite(max(dist))) && !(attr(dist, "tolerance") > 0)) {
Expand Down
31 changes: 16 additions & 15 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -741,33 +741,33 @@ create_stan_delays <- function(..., time_points = 1L) {
delays <- list(...)
## discretise
delays <- map(delays, discretise, strict = FALSE)
## convolve where appropriate
delays <- map(delays, collapse)
## apply tolerance
delays <- map(delays, function(x) {
apply_tolerance(x, tolerance = attr(x, "tolerance"))
})
## get maximum delays
max_delay <- unname(as.numeric(flatten(map(delays, max))))
## number of different non-empty types
type_n <- lengths(delays)
type_n <- vapply(delays, ndist, integer(1))
## assign ID values to each type
ids <- rep(0L, length(type_n))
ids[type_n > 0] <- seq_len(sum(type_n > 0))
names(ids) <- paste(names(type_n), "id", sep = "_")

flat_delays <- flatten(delays)
## create "flat version" of delays, i.e. a list of all the delays (including
## elements of composite delays)
if (length(delays) > 1) {
flat_delays <- do.call(c, delays)
} else {
flat_delays <- delays
}
parametric <- unname(vapply(
flat_delays, function(x) x$distribution != "nonparametric", logical(1)
flat_delays, function(x) get_distribution(x) != "nonparametric", logical(1)
))
param_length <- unname(vapply(flat_delays[parametric], function(x) {
length(x$parameters)
length(get_parameters(x))
}, numeric(1)))
nonparam_length <- unname(vapply(flat_delays[!parametric], function(x) {
length(x$pmf)
}, numeric(1)))
distributions <- unname(as.character(
map(flat_delays[parametric], ~ .x$distribution)
map(flat_delays[parametric], get_distribution)
))

## create stan object
Expand All @@ -788,15 +788,16 @@ create_stan_delays <- function(..., time_points = 1L) {
ret$types_groups <- array(c(0, cumsum(unname(type_n[type_n > 0]))) + 1)

ret$params_mean <- array(unname(as.numeric(
map(flatten(map(flat_delays[parametric], ~ .x$parameters)), mean)
map(flatten(map(flat_delays[parametric], get_parameters)), mean)
)))
ret$params_sd <- array(unname(as.numeric(
map(flatten(map(flat_delays[parametric], ~ .x$parameters)), sd_dist)
map(flatten(map(flat_delays[parametric], get_parameters)), sd)
)))
ret$params_sd[is.na(ret$params_sd)] <- 0
ret$max <- array(max_delay[parametric])

ret$np_pmf <- array(unname(as.numeric(
flatten(map(flat_delays[!parametric], ~ .x$pmf))
flatten(map(flat_delays[!parametric], get_pmf))
)))
## get non zero length delay pmf lengths
ret$np_pmf_groups <- array(c(0, cumsum(nonparam_length)) + 1)
Expand All @@ -809,7 +810,7 @@ create_stan_delays <- function(..., time_points = 1L) {
## set lower bounds
ret$params_lower <- array(unname(as.numeric(flatten(
map(flat_delays[parametric], function(x) {
lower_bounds(x$distribution)[names(x$parameters)]
lower_bounds(get_distribution(x))[names(get_parameters(x))]
})
))))
## assign prior weights
Expand Down
85 changes: 35 additions & 50 deletions R/deprecated.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ adjust_infection_to_report <- function(infections, delay_defs,
#' @param fixed Deprecated, use [fix_dist()] instead.
#' @return A list of distribution options.
#' @importFrom rlang warn arg_match
#' @export
#' @keywords internal
dist_spec <- function(distribution = c(
"lognormal", "normal", "gamma", "fixed", "empty"
Expand Down Expand Up @@ -485,55 +484,7 @@ rstan_opts <- function(object = NULL,
#' Samples outside of this range are resampled.
#'
#' @return A vector of samples or a probability distribution.
#' @export
#' @examples
#'
#' ## Exponential model
#' # sample
#' dist_skel(10, model = "exp", params = list(rate = 1))
#'
#' # cumulative prob density
#' dist_skel(1:10, model = "exp", dist = TRUE, params = list(rate = 1))
#'
#' # probability density
#' dist_skel(1:10,
#' model = "exp", dist = TRUE,
#' cum = FALSE, params = list(rate = 1)
#' )
#'
#' ## Gamma model
#' # sample
#' dist_skel(10, model = "gamma", params = list(shape = 1, rate = 0.5))
#'
#' # cumulative prob density
#' dist_skel(0:10,
#' model = "gamma", dist = TRUE,
#' params = list(shape = 1, rate = 0.5)
#' )
#'
#' # probability density
#' dist_skel(0:10,
#' model = "gamma", dist = TRUE,
#' cum = FALSE, params = list(shape = 2, rate = 0.5)
#' )
#'
#' ## Log normal model
#' # sample
#' dist_skel(10,
#' model = "lognormal", params = list(meanlog = log(5), sdlog = log(2))
#' )
#'
#' # cumulative prob density
#' dist_skel(0:10,
#' model = "lognormal", dist = TRUE,
#' params = list(meanlog = log(5), sdlog = log(2))
#' )
#'
#' # probability density
#' dist_skel(0:10,
#' model = "lognormal", dist = TRUE, cum = FALSE,
#' params = list(meanlog = log(5), sdlog = log(2))
#' )
#' @keywords internal
dist_skel <- function(n, dist = FALSE, cum = TRUE, model,
discrete = FALSE, params, max_value = 120) {
lifecycle::deprecate_warn(
Expand Down Expand Up @@ -633,3 +584,37 @@ dist_skel <- function(n, dist = FALSE, cum = TRUE, model,
sample <- truncated_skel(n, dist = dist, cum = cum, max_value = max_value)
return(sample)
}

#' Applies a threshold to all nonparametric distributions in a <dist_spec>
#'
#' @description `r lifecycle::badge("deprecated")`
#' This function is deprecated. Use `bound_dist()` instead.
#' @param x A `<dist_spec>`
#' @param tolerance Numeric; the desired tolerance level. Any part of the
#' cumulative distribution function beyond 1 minus this tolerance level is
#' removed.
#' @return A `<dist_spec>` where probability masses below the threshold level
#' have been removed
#' @keywords internal
apply_tolerance <- function(x, tolerance) {
lifecycle::deprecate_warn(
"1.6.0", "apply_tolerance()", "bound_dist()"
)
if (!is(x, "dist_spec")) {
stop("Can only apply tolerance to distributions in a <dist_spec>.")
}
y <- lapply(x, function(x) {
if (x$distribution == "nonparametric") {
cmf <- cumsum(x$pmf)
new_pmf <- x$pmf[c(TRUE, (1 - cmf[-length(cmf)]) >= tolerance)]
x$pmf <- new_pmf / sum(new_pmf)
return(x)
} else {
return(x)
}
})

## preserve attributes
attributes(y) <- attributes(x)
return(y)
}
Loading

0 comments on commit ee4043d

Please sign in to comment.