Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement dist_spec interface #363

Merged
merged 106 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
106 commits
Select commit Hold shift + click to select a range
497dbf8
implement `dist_spec` interface
sbfnk Feb 7, 2023
2de8c1f
Add skipping of stan tests in the expected places
seabbs Feb 8, 2023
31729e5
fix checks
sbfnk Feb 8, 2023
5cd962e
fix epinow example
sbfnk Feb 8, 2023
01be1c4
update syntax in more places
sbfnk Feb 8, 2023
47fb373
update another example
sbfnk Feb 8, 2023
d4b2d67
move parenthesis to the right place
sbfnk Feb 8, 2023
aa4101e
update uncertainty in estimate_infections example
sbfnk Feb 8, 2023
97074a8
fix use of generation_time_opts as resource to call get_generation_ti…
seabbs Feb 8, 2023
45bf995
break line to make R CMD CHECK happy
sbfnk Feb 8, 2023
ac5613a
fix uses of `trunc_opts`
sbfnk Feb 8, 2023
a7b025b
fix updating of `cur_len` in ragged convolution
sbfnk Feb 8, 2023
e65a654
remove bounds on mean parameters
sbfnk Feb 8, 2023
b282270
use dist_spec syntax in `estimate_delays`
sbfnk Feb 9, 2023
3c9ee68
add print function for `dist_spec`
sbfnk Feb 9, 2023
40e2586
add names to printing if given
sbfnk Feb 9, 2023
59976d7
clarify printouts
sbfnk Feb 9, 2023
cae4b4f
reduce unnecessary function calls
sbfnk Feb 13, 2023
28d46a5
Revert "reduce unnecessary function calls"
sbfnk Feb 13, 2023
e5d2c35
fix typo
sbfnk Feb 9, 2023
c77df4a
add default option for generation time
sbfnk Mar 2, 2023
024be5c
simplify delay inits
sbfnk Mar 3, 2023
1374338
update pmf doc
sbfnk Mar 3, 2023
522f12b
dist -> distribution
sbfnk Mar 3, 2023
4d295e3
fix max of np dist logic
sbfnk Mar 3, 2023
dfefb72
simplify pmf truncation syntax
sbfnk Mar 3, 2023
f02b321
fix typos and use `is`
sbfnk Apr 27, 2023
bcefdf7
extract function for stan code conversion
sbfnk Mar 3, 2023
c765caf
fix variable name
sbfnk Mar 3, 2023
00d948e
fix function name
sbfnk Mar 3, 2023
f307aa9
do truncnorm with appropriate lengths
sbfnk Mar 3, 2023
d2682ea
fix initial condition sampling
sbfnk Mar 3, 2023
cfd039b
update `to_stan` documentation
sbfnk Mar 3, 2023
f458a30
fix typo
sbfnk Mar 3, 2023
239bb69
stan model with unified delays
sbfnk Apr 25, 2023
3e19d26
update R access to unified dist interface
sbfnk Apr 27, 2023
7a00103
update tests
sbfnk Apr 27, 2023
600c4a0
ensure arrays are arrays
sbfnk Apr 27, 2023
e372866
simplify stan seq (and avoid conflict with R)
sbfnk Apr 27, 2023
42b374d
fix test
sbfnk Apr 28, 2023
3da2fd4
fix simulation models
sbfnk Apr 28, 2023
2df8a68
fix final tests
sbfnk Apr 28, 2023
9dae328
update usage of c -> +
sbfnk Apr 28, 2023
02f70fd
Automatic readme update
actions-user Apr 28, 2023
3d256b7
update examples/doc and re-doc
sbfnk Apr 28, 2023
1deace7
linting
sbfnk Apr 28, 2023
1eacb5c
update docs
sbfnk Apr 28, 2023
dc321a1
final requested lint
sbfnk Apr 28, 2023
d0932fb
update return type of bootstrapped_dist_fit
sbfnk Apr 28, 2023
cc6a2d9
redoc
sbfnk Apr 28, 2023
563502a
update estimate_delay to reflect changes
sbfnk Apr 28, 2023
66d5a55
dot product for all convolutions
sbfnk May 2, 2023
b6186c6
report gt mean and var
sbfnk May 2, 2023
610b1db
bug fix in calculation of max delays
sbfnk May 2, 2023
47c9e16
Automatic readme update
actions-user May 2, 2023
094771e
update tests
sbfnk May 2, 2023
62d7b5a
clean whitespace
sbfnk May 2, 2023
2969307
reduce number of calculations by precomputing len
sbfnk May 2, 2023
849f358
optional head/tail
sbfnk May 10, 2023
7435597
Revert "optional head/tail"
sbfnk May 10, 2023
dde5534
don't convolve first pmf
sbfnk May 10, 2023
f18054f
reduce vector copying
sbfnk May 11, 2023
e4c3524
fix reversing
sbfnk May 12, 2023
bbd1e4b
fix printing of combined distributions
sbfnk May 12, 2023
529567a
add exampples, export, and add basic dist plotting
seabbs Jun 8, 2023
28ebd6a
Automatic readme update
actions-user Jun 7, 2023
8430bc0
add some tests for dist_spec
seabbs Jun 8, 2023
f4a08c7
add tests for +.dist_spec
seabbs Jun 8, 2023
9ccb739
add tests for mean.dist_spec
seabbs Jun 8, 2023
5e6e04f
add some basic additional tests and docs
seabbs Jun 8, 2023
721799a
linting
seabbs Jun 8, 2023
f90cefa
fix linting
seabbs Jun 8, 2023
a8402f4
export c
seabbs Jun 8, 2023
bef8fc2
fix plotting to work with c() method for dist_spec
seabbs Jun 8, 2023
44f8510
more linting fixes
seabbs Jun 9, 2023
89180e4
remove extract line in generation_time.stan
seabbs Jun 9, 2023
028853f
add a check in convolve_rev_pmf when len >= xlen + ylen and update tests
seabbs Jun 9, 2023
724dd65
be more efficient when calc discrete pmfs
seabbs Jun 9, 2023
b0aeb65
catch missing indexes in omf calc
seabbs Jun 9, 2023
7b72e31
clarify comments
seabbs Jun 9, 2023
be61d97
add tolerance for +.dist_spec
seabbs Jun 9, 2023
c8d70a0
don't load testthat
seabbs Jun 9, 2023
a46c9a9
trigger benchmarking
seabbs Jun 9, 2023
66aa1f6
remove benchmark trigger
seabbs Jun 9, 2023
dee73a2
linting
seabbs Jun 9, 2023
022923f
Automatic readme update
actions-user Jun 9, 2023
efd096c
trigger benchmarking
seabbs Jun 9, 2023
5d38f35
remove benchmark trigger
seabbs Jun 9, 2023
e5ef461
refine tolerance checks for convolution
seabbs Jun 9, 2023
b26b626
fix example
seabbs Jun 9, 2023
ea753c6
add an internal function
seabbs Jun 9, 2023
4613147
trigger benchmarking
seabbs Jun 9, 2023
706629e
benchmarking off
seabbs Jun 9, 2023
254d9b4
add back in missing tolerance docs
seabbs Jun 12, 2023
3c123c4
fix edge case check for length 1 pmfs
seabbs Jun 12, 2023
2c81f97
whitespace linting
seabbs Jun 12, 2023
f4a85c0
test more carefully
seabbs Jun 12, 2023
b8e044e
use commas like a smart boy
seabbs Jun 12, 2023
727445a
crank that adapt delta handle
seabbs Jun 12, 2023
9ecca7e
Update R/create.R
seabbs Jun 13, 2023
400d3d8
Update R/get.R
seabbs Jun 13, 2023
742744e
Update R/opts.R
seabbs Jun 13, 2023
2e8c592
Update R/dist.R
seabbs Jun 13, 2023
967d3cd
fixed @internal and brackets + fcase
seabbs Jun 13, 2023
749dc8e
don't export c.dist_spec
seabbs Jun 13, 2023
27722b0
drop c() examle from plot
seabbs Jun 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
^\.devcontainer$
^CRAN-SUBMISSION$
^touchstone$
^\.benchmark$
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ inst/include/*.o
src

.DS_Store
.vscode
8 changes: 8 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Generated by roxygen2: do not edit by hand

S3method("+",dist_spec)
S3method(mean,dist_spec)
S3method(plot,dist_spec)
S3method(plot,epinow)
S3method(plot,estimate_infections)
S3method(plot,estimate_secondary)
S3method(plot,estimate_truncation)
S3method(print,dist_spec)
S3method(summary,epinow)
S3method(summary,estimate_infections)
export(R_to_growth)
Expand Down Expand Up @@ -145,6 +149,7 @@ importFrom(ggplot2,geom_line)
importFrom(ggplot2,geom_linerange)
importFrom(ggplot2,geom_point)
importFrom(ggplot2,geom_ribbon)
importFrom(ggplot2,geom_step)
importFrom(ggplot2,geom_vline)
importFrom(ggplot2,ggplot)
importFrom(ggplot2,ggplot_build)
Expand All @@ -157,6 +162,7 @@ importFrom(ggplot2,scale_x_date)
importFrom(ggplot2,scale_y_continuous)
importFrom(ggplot2,theme)
importFrom(ggplot2,theme_bw)
importFrom(ggplot2,vars)
importFrom(lifecycle,deprecate_soft)
importFrom(lifecycle,deprecate_warn)
importFrom(lubridate,days)
Expand Down Expand Up @@ -188,6 +194,7 @@ importFrom(rstan,summary)
importFrom(rstan,vb)
importFrom(runner,mean_run)
importFrom(scales,comma)
importFrom(stats,convolve)
importFrom(stats,glm)
importFrom(stats,lm)
importFrom(stats,median)
Expand All @@ -207,5 +214,6 @@ importFrom(stats,sd)
importFrom(stats,var)
importFrom(truncnorm,rtruncnorm)
importFrom(utils,capture.output)
importFrom(utils,head)
importFrom(utils,tail)
useDynLib(EpiNow2, .registration=TRUE)
22 changes: 12 additions & 10 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,24 @@

This release is in development. For a stable release install 1.3.5 from CRAN.

## Breaking changes

- The external distribution interface has been updated to use the `dist_spec()` function. This comes with a range of benefits, including optimising model fitting when static delays are used (by convolving when first defined vs in stan), easy printing (using `print()`), and easy plotting (using `plot()`). It also makes it possible to use all supported distributions everywhere (i.e, as a generation time or reporting delay). However, this update will break most users code as the interface has changed. See the documentation for `dist_spec()` for more details. By @sbfnk in #363 and reviewed by @seabbs.

## Package

* Model description has been expanded to include more detail.
* Moved to a GitHub Action to only lint changed files.
* Linted the package with a wider range of default linters.
* Added a GitHub Action to build the README when it is altered.
* Added handling of edge case where we sample from the negative binomial with
mean close or equal to 0. By @sbfnk in #366.
* Replaced use of nested `ifelse()` and `data.table::fifelse()` in the
code base with `data.table::fcase()`. By @jamesmbaazam in #383 and reviewed by @seabbs.
* Reviewed the example in `calc_backcalc_data()` to call `calc_backcalc_data()`
instead of `create_gp_data()`. By @jamesmbaazam in #388 and reviewed by @seabbs.
* Model description has been expanded to include more detail. By @sbfnk in #373 and reviewed by @seabbs.
* Moved to a GitHub Action to only lint changed files. By @seabbs in #378.
* Linted the package with a wider range of default linters. By @seabbs in #378.
* Added a GitHub Action to build the README when it is altered. By @seabbs.
* Added handling of edge case where we sample from the negative binomial with mean close or equal to 0. By @sbfnk in #366 and reviewed by @seabbs.
* Replaced use of nested `ifelse()` and `data.table::fifelse()` in the code base with `data.table::fcase()`. By @jamesmbaazam in #383 and reviewed by @seabbs.
* Reviewed the example in `calc_backcalc_data()` to call `calc_backcalc_data()` instead of `create_gp_data()`. By @jamesmbaazam in #388 and reviewed by @seabbs.
* Improved compilation times by reducing the number of distinct stan models and deprecated `tune_inv_gamma()`. By @sbfnk in #394 and reviewed by @seabbs.
* Changed touchstone settings so that benchmarks are only performed if the stan model is changed. By @sbfnk in #400 and reviewed by @seabbs.
* [pak](https://pak.r-lib.org/) is now suggested for installing the developmental version of the package. By @jamesmbaazam in #407 and reviewed by @seabbs. This has been successfully tested on MacOS Ventura, Ubuntu 20.04, and Windows 10. Users are advised to use `remotes::install_github("epiforecasts/EpiNow2")` if `pak` fails and if both fail, raise an issue.
* `dist_fit()`'s `samples` argument now takes a default value of 1000 instead of NULL. If a supplied `samples` is less than 1000, it is changed to 1000 and a warning is thrown to indicate the change. By @jamesmbazam in #389 and reviewed by @seabbs.
* The internal distribution interface has been streamlined to reduce code duplication. By @sbfnk in #363 and reviewed by @seabbs.

# EpiNow2 1.3.5

Expand Down
154 changes: 82 additions & 72 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,8 @@ create_obs_model <- function(obs = obs_opts(), dates) {
#'
#' @param shifted_cases A dataframe of delay shifted cases
#'
#' @param truncation `r lifecycle::badge("experimental")` A list of options as
#' generated by `trunc_opts()` defining the truncation of observed data.
#' Defaults to `trunc_opts()`. See `estimate_truncation()` for an approach to
#' estimating truncation from data.
#' @param seeding_time Integer; seeding time, usually obtained using
#' `get_seeding_time()`
#'
#' @inheritParams create_gp_data
#' @inheritParams create_obs_model
Expand All @@ -430,34 +428,20 @@ create_obs_model <- function(obs = obs_opts(), dates) {
#' @author Sam Abbott
#' @author Sebastian Funk
#' @export
create_stan_data <- function(reported_cases, generation_time,
rt, gp, obs, delays, horizon,
backcalc, shifted_cases,
truncation) {
## make sure we have at least gt_max seeding time
delays$seeding_time <- max(delays$seeding_time, generation_time$max)
create_stan_data <- function(reported_cases, seeding_time,
rt, gp, obs, horizon,
backcalc, shifted_cases) {

## for backwards compatibility call generation_time_opts internally
if (is.list(generation_time) &&
all(c("mean", "mean_sd", "sd", "sd_sd") %in% names(generation_time))) {
generation_time <- do.call(generation_time_opts, generation_time)
}

cases <- reported_cases[(delays$seeding_time + 1):(.N - horizon)]$confirm
cases <- reported_cases[(seeding_time + 1):(.N - horizon)]$confirm

data <- list(
cases = cases,
shifted_cases = shifted_cases,
t = length(reported_cases$date),
horizon = horizon,
burn_in = 0
burn_in = 0,
seeding_time = seeding_time
)
# add gt data
data <- c(data, generation_time)
# add delay data
data <- c(data, delays)
# add truncation data
data <- c(data, truncation)
# add Rt data
data <- c(
data,
Expand All @@ -476,10 +460,6 @@ create_stan_data <- function(reported_cases, generation_time,
is.na(data$prior_infections) || is.null(data$prior_infections),
0, data$prior_infections
)
if (is.null(data$gt_weight)) {
## default: weigh by number of data points
data$gt_weight <- data$t - data$seeding_time - data$horizon
}
if (data$seeding_time > 1) {
safe_lm <- purrr::safely(stats::lm)
data$prior_growth <- safe_lm(log(confirm) ~ t, data = first_week)[[1]]
Expand Down Expand Up @@ -532,37 +512,20 @@ create_stan_data <- function(reported_cases, generation_time,
create_initial_conditions <- function(data) {
init_fun <- function() {
out <- list()
if (data$n_uncertain_mean_delays > 0) {
out$delay_mean <- array(purrr::map2_dbl(
data$delay_mean_mean[data$uncertain_mean_delays],
data$delay_mean_sd[data$uncertain_mean_delays] * 0.1,
~ rnorm(1, mean = .x, sd = .y)
if (data$delay_n_p > 0) {
out$delay_mean <- array(truncnorm::rtruncnorm(
n = data$delay_n_p, a = 0,
mean = data$delay_mean_mean, sd = data$delay_mean_sd * 0.1
))
}
if (data$n_uncertain_sd_delays > 0) {
out$delay_sd <- array(purrr::map2_dbl(
data$delay_sd_mean[data$uncertain_sd_delays],
data$delay_sd_sd[data$uncertain_sd_delays] * 0.1,
~ rnorm(1, mean = .x, sd = .y)
out$delay_sd <- array(truncnorm::rtruncnorm(
n = data$delay_n_p, a = 0,
mean = data$delay_sd_mean, sd = data$delay_sd_sd * 0.1
))
} else {
out$delay_mean <- array(numeric(0))
out$delay_sd <- array(numeric(0))
}
if (data$truncation > 0) {
if (data$trunc_mean_sd > 0) {
out$truncation_mean <- array(rnorm(1,
mean = data$trunc_mean_mean,
sd = data$trunc_mean_sd * 0.1
))
}
if (data$trunc_sd_sd > 0) {
out$truncation_sd <- array(
truncnorm::rtruncnorm(1,
a = 0,
mean = data$trunc_sd_mean,
sd = data$trunc_sd_sd * 0.1
)
)
}
}

if (data$fixed == 0) {
out$eta <- array(rnorm(data$M, mean = 0, sd = 0.1))
out$rho <- array(rlnorm(1,
Expand All @@ -579,6 +542,10 @@ create_initial_conditions <- function(data) {
out$alpha <- array(
truncnorm::rtruncnorm(1, a = 0, mean = 0, sd = data$alpha_sd)
)
} else {
out$eta <- array(numeric(0))
out$rho <- array(numeric(0))
out$alpha <- array(numeric(0))
}
if (data$model_type == 1) {
out$rep_phi <- array(
Expand All @@ -597,30 +564,23 @@ create_initial_conditions <- function(data) {
n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd),
sd = convert_to_logsd(data$r_mean, data$r_sd) * 0.1
))
if (data$gt_mean_sd > 0) {
out$gt_mean <- array(truncnorm::rtruncnorm(1,
a = 0, mean = data$gt_mean_mean,
sd = data$gt_mean_sd * 0.1
))
}
if (data$gt_sd_sd > 0) {
out$gt_sd <- array(truncnorm::rtruncnorm(1,
a = 0, mean = data$gt_sd_mean,
sd = data$gt_sd_sd * 0.1
))
}
}

if (data$bp_n > 0) {
out$bp_sd <- array(truncnorm::rtruncnorm(1, a = 0, mean = 0, sd = 0.1))
out$bp_effects <- array(rnorm(data$bp_n, 0, 0.1))
}
if (data$bp_n > 0) {
out$bp_sd <- array(truncnorm::rtruncnorm(1, a = 0, mean = 0, sd = 0.1))
out$bp_effects <- array(rnorm(data$bp_n, 0, 0.1))
} else {
out$bp_sd <- array(numeric(0))
out$bp_effects <- array(numeric(0))
}
if (data$obs_scale == 1) {
out$frac_obs <- array(truncnorm::rtruncnorm(1,
a = 0, b = 1,
mean = data$obs_scale_mean,
sd = data$obs_scale_sd * 0.1
))
} else {
out$frac_obs <- array(numeric(0))
}
if (data$week_effect > 0) {
out$day_of_week_simplex <- array(
Expand Down Expand Up @@ -675,3 +635,53 @@ create_stan_args <- function(stan = stan_opts(),
args$return_fit <- NULL
return(args)
}

##' Create delay variables for stan
##'
##' @param ... Named delay distributions specified using `dist_spec()`.
##' The names are assigned to IDs
##' @param ot Integer, number of observations (needed if weighing any priors)
##' with the number of observations
##' @return A list of variables as expected by the stan model
##' @importFrom purrr transpose map
##' @author Sebastian Funk
create_stan_delays <- function(..., ot) {
seabbs marked this conversation as resolved.
Show resolved Hide resolved
dot_args <- list(...)
## combine delays
combined_delays <- unclass(c(...))
## number of different non-empty types
type_n <- unlist(purrr::transpose(dot_args)$n)
## assign ID values to each type
ids <- rep(0L, length(type_n))
ids[type_n > 0] <- seq_len(sum(type_n > 0))
names(ids) <- paste(names(type_n), "id", sep = "_")

## start consructing stan object
ret <- unclass(combined_delays)
## construct additional variables
ret <- c(ret, list(
types = sum(type_n > 0),
types_p = array(1L - combined_delays$fixed)
))
## delay identifiers
ret$types_id <- integer(0)
ret$types_id[ret$types_p == 1] <- seq_len(ret$n_p)
ret$types_id[ret$types_p == 0] <- seq_len(ret$n_np)
ret$types_id <- array(ret$types_id)
## map delays to identifiers
ret$types_groups <- array(c(0, cumsum(unname(type_n[type_n > 0]))) + 1)
## map pmfs
ret$np_pmf_groups <- array(c(0, cumsum(combined_delays$np_pmf_length)) + 1)
## assign prior weights
if (any(ret$weight == 0)) {
ret$weight[ret$weight == 0] <- ot
}
## remove auxiliary variables
ret$fixed <- NULL
ret$np_pmf_length <- NULL

names(ret) <- paste("delay", names(ret), sep = "_")
ret <- c(ret, ids)

return(ret)
}
Loading