Skip to content

Commit

Permalink
make tolerance user-settable
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk committed Jan 12, 2024
1 parent 05f3d71 commit 8fc0f77
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 14 deletions.
11 changes: 6 additions & 5 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -669,17 +669,18 @@ create_stan_args <- function(stan = stan_opts(),
##'
##' @param ... Named delay distributions. The names are assigned to IDs
##' @param weight Numeric, weight associated with delay priors; default: 1
##' @inheritParams apply_tolerance
##' @return A list of variables as expected by the stan model
##' @importFrom purrr transpose map flatten
##' @author Sebastian Funk
create_stan_delays <- function(..., weight = 1, tolerance = 0.001) {
create_stan_delays <- function(..., weight = 1) {
## discretise
delays <- lapply(list(...), discretise)
delays <- map(list(...), discretise)
## convolve where appropriate
delays <- lapply(delays, collapse)
delays <- map(delays, collapse)
## apply tolerance
delays <- lapply(delays, apply_tolerance, tolerance = tolerance)
delays <- map(delays, function(x) {
apply_tolerance(x, tolerance = attr(x, "tolerance"))
})
## get maximum delays
max_delay <- unname(as.numeric(flatten(map(delays, max))))
## number of different non-empty types
Expand Down
12 changes: 9 additions & 3 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#' @param prior_weight deprecated; prior weights are now specified as a
#' model option. Use the `weigh_delay_priors` argument of
#' [estimate_infections()] instead.
#' @inheritParams apply_tolerance
#' @return A `<generation_time_opts>` object summarising the input delay
#' distributions.
#' @author Sebastian Funk
Expand Down Expand Up @@ -41,7 +42,7 @@
#' generation_time_opts(example_generation_time)
generation_time_opts <- function(dist = Fixed(1), ...,
disease, source, max = 14, fixed = FALSE,
prior_weight) {
prior_weight, tolerance = 0.001) {
deprecated_options_given <- FALSE
dot_options <- list(...)

Expand Down Expand Up @@ -99,6 +100,7 @@ generation_time_opts <- function(dist = Fixed(1), ...,
"`?generation_time_opts`")
}
check_stan_delay(dist)
attr(dist, "tolerance") <- tolerance
attr(dist, "class") <- c("generation_time_opts", class(dist))
return(dist)
}
Expand All @@ -112,6 +114,7 @@ generation_time_opts <- function(dist = Fixed(1), ...,
#' a fixed distribution with all mass at 0, i.e. no delay.
#' @param ... deprecated; use `dist` instead
#' @param fixed deprecated; use `dist` instead
#' @inheritParams apply_tolerance
#' @return A `<delay_opts>` object summarising the input delay distributions.
#' @author Sam Abbott
#' @author Sebastian Funk
Expand All @@ -132,7 +135,7 @@ generation_time_opts <- function(dist = Fixed(1), ...,
#'
#' # Multiple delays (in this case twice the same)
#' delay_opts(delay + delay)
delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE) {
delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE, tolerance = 0.001) {
dot_options <- list(...)
if (!is(dist, "dist_spec")) { ## could be old syntax
if (is.list(dist)) {
Expand Down Expand Up @@ -164,6 +167,7 @@ delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE) {
stop("Unknown named arguments passed to `delay_opts`")
}
check_stan_delay(dist)
attr(dist, "tolerance") <- tolerance
attr(dist, "class") <- c("delay_opts", class(dist))
return(dist)
}
Expand All @@ -178,6 +182,7 @@ delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE) {
#' @param dist A delay distribution or series of delay distributions reflecting
#' the truncation generated using [dist_spec()] or [estimate_truncation()].
#' Default is fixed distribution with maximum 0, i.e. no truncation
#' @inheritParams apply_tolerance
#' @return A `<trunc_opts>` object summarising the input truncation
#' distribution.
#'
Expand All @@ -192,7 +197,7 @@ delay_opts <- function(dist = Fixed(0), ..., fixed = FALSE) {
#'
#' # truncation dist
#' trunc_opts(dist = dist_spec(mean = 3, sd = 2, max = 10))
trunc_opts <- function(dist = Fixed(0)) {
trunc_opts <- function(dist = Fixed(0), tolerance = 0.001) {
if (!is(dist, "dist_spec")) {
if (is.list(dist)) {
dist <- do.call(dist_spec, dist)
Expand All @@ -209,6 +214,7 @@ trunc_opts <- function(dist = Fixed(0)) {
)
}
check_stan_delay(dist)
attr(dist, "tolerance") <- tolerance
attr(dist, "class") <- c("trunc_opts", class(dist))
return(dist)
}
Expand Down
4 changes: 1 addition & 3 deletions man/create_stan_delays.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/delay_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/generation_time_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/trunc_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 8fc0f77

Please sign in to comment.