Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 757: Vectorise GP stan code #742

Merged
merged 107 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
107 commits
Select commit Hold shift + click to select a range
9c3b8dd
refactor gps along approach used by aki
seabbs Aug 9, 2024
725cbb3
add some skeleton unit tests
seabbs Aug 9, 2024
20f76bf
update model
seabbs Aug 9, 2024
478e834
update data chunk
seabbs Aug 9, 2024
32572a4
add r interface code
seabbs Aug 9, 2024
8ed3295
add some R side tests where missing for gp related code
seabbs Aug 9, 2024
f20f409
get tests passing
seabbs Aug 12, 2024
89de657
fix tests for create_gp_data
seabbs Aug 12, 2024
0fbbcba
fix tests
seabbs Aug 12, 2024
17f9992
update docs
seabbs Aug 12, 2024
5b3c0b4
fix tests warnings to error for uncommon mattern orders
seabbs Aug 12, 2024
602f28c
remove spurious plots
seabbs Aug 12, 2024
1e5eb4a
update EpiNow2 vignette
seabbs Aug 12, 2024
0bf3605
review kkernals
seabbs Aug 12, 2024
4f8c5bc
update and test
seabbs Aug 12, 2024
54161f0
correct scaling of L
seabbs Aug 12, 2024
092a794
rescale lengthscale
seabbs Aug 13, 2024
a4b04a8
change adapt-delta default to 0.9
seabbs Aug 13, 2024
d8a1894
expand inits for GP as causing issues as to close to 1/0
seabbs Aug 13, 2024
ebcfb89
get rid of normalisation and use unormalised lpdf where possible (equiv)
seabbs Aug 13, 2024
4702388
widen optimisation sweep to include delay weight default
seabbs Aug 13, 2024
706b9de
non-center random walk
seabbs Aug 13, 2024
92624fa
tune prior specification
seabbs Aug 13, 2024
657c27d
tune dispersion prior
seabbs Aug 13, 2024
0701b7b
tune phi
seabbs Aug 13, 2024
561bdb3
update vignette
seabbs Aug 13, 2024
a1344d4
get rid of rw change
seabbs Aug 13, 2024
ca14c13
revert Rt
seabbs Aug 13, 2024
16354c0
catch update_rt
seabbs Aug 13, 2024
0da8c90
add news
seabbs Aug 13, 2024
c28af7a
revert vignette changes
seabbs Aug 13, 2024
eb663b3
fix gp_opts tests
seabbs Aug 13, 2024
6247bfa
fix create tests
seabbs Aug 13, 2024
e2f4b13
update gp tests
seabbs Aug 14, 2024
137ecf2
skip tests as required on windows
seabbs Aug 14, 2024
299268a
fix linting
seabbs Aug 14, 2024
1b70596
constrain delay uncertainty
seabbs Aug 14, 2024
128c592
correct gp stan tests
seabbs Aug 14, 2024
6fd326f
fix GP test
seabbs Aug 14, 2024
b9153b6
put the deprecition warning behind a gate
seabbs Aug 14, 2024
e338331
refactor gps along approach used by aki
seabbs Aug 9, 2024
9297afd
add some skeleton unit tests
seabbs Aug 9, 2024
8dec848
update model
seabbs Aug 9, 2024
b503e82
update data chunk
seabbs Aug 9, 2024
e96a4fc
add r interface code
seabbs Aug 9, 2024
3cfc9aa
add some R side tests where missing for gp related code
seabbs Aug 9, 2024
783a766
get tests passing
seabbs Aug 12, 2024
e296a89
fix tests for create_gp_data
seabbs Aug 12, 2024
fc15818
fix tests
seabbs Aug 12, 2024
5edd8ef
update docs
seabbs Aug 12, 2024
282de44
fix tests warnings to error for uncommon mattern orders
seabbs Aug 12, 2024
d333a81
update EpiNow2 vignette
seabbs Aug 12, 2024
9e990ce
review kkernals
seabbs Aug 12, 2024
afc2520
update and test
seabbs Aug 12, 2024
a01c7c4
correct scaling of L
seabbs Aug 12, 2024
19da24c
rescale lengthscale
seabbs Aug 13, 2024
0fb1dc6
change adapt-delta default to 0.9
seabbs Aug 13, 2024
c0d0dbb
expand inits for GP as causing issues as to close to 1/0
seabbs Aug 13, 2024
3ff3cb9
get rid of normalisation and use unormalised lpdf where possible (equiv)
seabbs Aug 13, 2024
9dba492
widen optimisation sweep to include delay weight default
seabbs Aug 13, 2024
7c16cd6
non-center random walk
seabbs Aug 13, 2024
5d3801b
tune prior specification
seabbs Aug 13, 2024
fee0289
tune dispersion prior
seabbs Aug 13, 2024
fea377c
tune phi
seabbs Aug 13, 2024
89cafd6
update vignette
seabbs Aug 13, 2024
190c800
get rid of rw change
seabbs Aug 13, 2024
9bd1989
revert Rt
seabbs Aug 13, 2024
7090af2
catch update_rt
seabbs Aug 13, 2024
da5c1ea
add news
seabbs Aug 13, 2024
38ae947
revert vignette changes
seabbs Aug 13, 2024
0348039
fix gp_opts tests
seabbs Aug 13, 2024
883d6cb
fix create tests
seabbs Aug 13, 2024
f476353
update gp tests
seabbs Aug 14, 2024
1f39fa3
skip tests as required on windows
seabbs Aug 14, 2024
30eb30f
fix linting
seabbs Aug 14, 2024
c6328e4
constrain delay uncertainty
seabbs Aug 14, 2024
de9348e
correct gp stan tests
seabbs Aug 14, 2024
c47fd0f
fix GP test
seabbs Aug 14, 2024
92db47c
put the deprecition warning behind a gate
seabbs Aug 14, 2024
d767118
fix linting
seabbs Aug 14, 2024
239a710
Update NEWS.md
seabbs Aug 14, 2024
40aae17
merge
seabbs Aug 15, 2024
f781c3c
add linear kernel support
seabbs Aug 15, 2024
fa7d2f4
add docs and newa
seabbs Aug 15, 2024
ecf8965
integration tests and minor issues
seabbs Aug 15, 2024
e0c3849
fixes for periodic kernel dimension differences
seabbs Aug 15, 2024
242ccb6
drop linear kernel support
seabbs Aug 19, 2024
4cb3de7
lint space
seabbs Aug 19, 2024
adf9412
catch outstanding linear tests
seabbs Aug 19, 2024
a46bc6e
catch stan tests
seabbs Aug 19, 2024
20292a7
make the eecdf in convolve test less random
seabbs Aug 20, 2024
abe77f6
Update R/create.R
seabbs Aug 27, 2024
c8a8ca4
Merge branch 'main' into vectorise-spectral-density
seabbs Aug 27, 2024
1b73919
Update NEWS.md
seabbs Aug 28, 2024
a42f31d
Update create.R - remove out of date gp type 3 check
seabbs Aug 28, 2024
5c39a86
Update opts.R - remove linear kernel references
seabbs Aug 28, 2024
83f372f
Update R/opts.R
seabbs Aug 28, 2024
93bb8c5
Update opts.R - fix review suggestions
seabbs Aug 28, 2024
1074124
Update opts.R - remove linear reference
seabbs Aug 28, 2024
37be7ff
Update estimate_infections.stan
seabbs Aug 28, 2024
94df110
Update tests/testthat/test-create_gp_data.R
seabbs Aug 28, 2024
09a8abb
Update NEWS.md
seabbs Aug 28, 2024
078d550
Update opts.R
seabbs Aug 28, 2024
71eb3be
Document
actions-user Aug 28, 2024
13f7f8a
Merge branch 'main' into vectorise-spectral-density
seabbs Aug 28, 2024
7ff2590
Merge branch 'main' into vectorise-spectral-density
seabbs Aug 28, 2024
7532c83
Merge branch 'main' into vectorise-spectral-density
seabbs Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
- `epinow()` now returns the "timing" output in a "time difference"" format that is easier to understand and work with. By @jamesmbaazam in #688 and reviewed by @sbfnk.
- The interface for defining delay distributions has been generalised to also cater for continuous distributions
- When defining probability distributions these can now be truncated using the `tolerance` argument
- Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in #741 and reviewed by @seabbs.
- Gaussian processes have been vectorised, leading to some speed gains 🚀 , and the `gp_opts()` function has gained three more options, "periodic", "ou", and "se", to specify periodic and linear kernels respectively. By @seabbs in #742 and reviewed by @jamesmbaazam.
- Prior predictive checks have been used to update the following priors: the prior on the magnitude of the Gaussian process (from HalfNormal(0, 1) to HalfNormal(0, 0.1)), and the prior on the overdispersion (from 1 / HalfNormal(0, 1)^2 to 1 / HalfNormal(0, 0.25)). In the user-facing API, this is a change in default values of the `sd` of `phi` in `obs_opts()` from 1 to 0.25. By @seabbs in #742 and reviewed by @jamesmbaazam.
- The default stan control options have been updated from `list(adapt_delta = 0.95, max_treedepth = 15)` to `list(adapt_delta = 0.9, max_treedepth = 12)` due to improved performance and to reduce the runtime of the default parameterisations. By @seabbs in #742 and reviewed by @jamesmbaazam.
- Initialisation has been simplified by sampling directly from the priors, where possible, rather than from a constrained space. By @seabbs in #742 and reviewed by @jamesmbaazam.
- Unnecessary normalisation of delay priors has been removed. By @seabbs in #742 and reviewed by @jamesmbaazam.
- Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in # and reviewed by @.
- Optimised convolution code to take into account the relative length of the vectors being convolved. See #745 by @seabbs and reviewed by @jamesmbaazam.
- Switch to broadcasting the day of the week effect. By @seabbs in #746 and reviewed by @jamesmbaazam.
Expand Down
80 changes: 50 additions & 30 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ create_backcalc_data <- function(backcalc = backcalc_opts()) {
)
return(data)
}

#' Create Gaussian Process Data
#'
#' @description `r lifecycle::badge("stable")`
Expand Down Expand Up @@ -372,33 +373,50 @@ create_gp_data <- function(gp = gp_opts(), data) {
} else {
fixed <- FALSE
}
# reset ls_max if larger than observed time
time <- data$t - data$seeding_time - data$horizon
if (gp$ls_max > time) {
gp$ls_max <- time

time <- data$t - data$seeding_time
if (data$future_fixed > 0) {
time <- time + data$fixed_from - data$horizon
}
if (data$stationary == 1) {
time <- time - 1
}

obs_time <- data$t - data$seeding_time
if (gp$ls_max > obs_time) {
gp$ls_max <- obs_time
}

times <- seq_len(time)

rescaled_times <- (times - mean(times)) / sd(times)
gp$ls_mean <- gp$ls_mean / sd(times)
gp$ls_sd <- gp$ls_sd / sd(times)
gp$ls_min <- gp$ls_min / sd(times)
gp$ls_max <- gp$ls_max / sd(times)

# basis functions
M <- data$t - data$seeding_time
M <- ifelse(data$future_fixed == 1, M - (data$horizon - data$fixed_from), M)
M <- ceiling(M * gp$basis_prop)
M <- ceiling(time * gp$basis_prop)

# map settings to underlying gp stan requirements
gp_data <- list(
fixed = as.numeric(fixed),
M = M,
L = gp$boundary_scale,
L = gp$boundary_scale * max(rescaled_times),
ls_meanlog = convert_to_logmean(gp$ls_mean, gp$ls_sd),
ls_sdlog = convert_to_logsd(gp$ls_mean, gp$ls_sd),
ls_min = gp$ls_min,
ls_max = data$t - data$seeding_time - data$horizon,
ls_max = gp$ls_max,
alpha_mean = gp$alpha_mean,
alpha_sd = gp$alpha_sd,
gp_type = data.table::fcase(
is.infinite(gp$matern_order), 0,
gp$matern_order == 1 / 2, 1,
gp$matern_order == 3 / 2, 2,
default = 3
)
gp$kernel == "se", 0,
gp$kernel == "periodic", 1,
gp$kernel == "matern" || gp$kernel == "ou", 2,
default = 2
),
nu = gp$matern_order,
w0 = gp$w0
)

gp_data <- c(data, gp_data)
Expand Down Expand Up @@ -591,42 +609,44 @@ create_initial_conditions <- function(data) {
out <- create_delay_inits(data)

if (data$fixed == 0) {
out$eta <- array(rnorm(data$M, mean = 0, sd = 0.1))
out$rho <- array(rlnorm(1,
out$eta <- array(rnorm(
ifelse(data$gp_type == 1, data$M * 2, data$M), mean = 0, sd = 0.1))
out$rescaled_rho <- array(rlnorm(1,
meanlog = data$ls_meanlog,
sdlog = ifelse(data$ls_sdlog > 0, data$ls_sdlog * 0.1, 0.01)
sdlog = ifelse(data$ls_sdlog > 0, data$ls_sdlog, 0.01)
))
out$rescaled_rho <- array(data.table::fcase(
out$rescaled_rho > data$ls_max, data$ls_max - 0.001,
out$rescaled_rho < data$ls_min, data$ls_min + 0.001,
default = out$rescaled_rho
))

out$rho <- array(data.table::fcase(
out$rho > data$ls_max, data$ls_max - 0.001,
out$rho < data$ls_min, data$ls_min + 0.001,
default = out$rho
))

out$alpha <- array(
truncnorm::rtruncnorm(1, a = 0, mean = 0, sd = data$alpha_sd)
truncnorm::rtruncnorm(
1, a = 0, mean = data$alpha_mean, sd = data$alpha_sd
)
)
} else {
out$eta <- array(numeric(0))
out$rho <- array(numeric(0))
out$rescaled_rho <- array(numeric(0))
out$alpha <- array(numeric(0))
}
if (data$model_type == 1) {
out$rep_phi <- array(
truncnorm::rtruncnorm(
1,
a = 0, mean = data$phi_mean, sd = data$phi_sd / 10
a = 0, mean = data$phi_mean, sd = data$phi_sd
)
)
}
if (data$estimate_r == 1) {
out$initial_infections <- array(rnorm(1, data$prior_infections, 0.02))
out$initial_infections <- array(rnorm(1, data$prior_infections, 0.2))
if (data$seeding_time > 1) {
out$initial_growth <- array(rnorm(1, data$prior_growth, 0.01))
out$initial_growth <- array(rnorm(1, data$prior_growth, 0.02))
}
out$log_R <- array(rnorm(
n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd),
sd = convert_to_logsd(data$r_mean, data$r_sd) * 0.1
sd = convert_to_logsd(data$r_mean, data$r_sd)
))
}

Expand All @@ -641,7 +661,7 @@ create_initial_conditions <- function(data) {
out$frac_obs <- array(truncnorm::rtruncnorm(1,
a = 0, b = 1,
mean = data$obs_scale_mean,
sd = data$obs_scale_sd * 0.1
sd = data$obs_scale_sd
))
} else {
out$frac_obs <- array(numeric(0))
Expand Down
3 changes: 1 addition & 2 deletions R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@
#' def <- estimate_infections(reported_cases,
#' generation_time = gt_opts(generation_time),
#' delays = delay_opts(incubation_period + reporting_delay),
#' rt = rt_opts(prior = list(mean = 2, sd = 0.1)),
#' stan = stan_opts(control = list(adapt_delta = 0.95))
seabbs marked this conversation as resolved.
Show resolved Hide resolved
#' rt = rt_opts(prior = list(mean = 2, sd = 0.1))
#' )
#' # real time estimates
#' summary(def)
Expand Down
98 changes: 56 additions & 42 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
#' @param weight_prior Logical; if TRUE (default), any priors given in `dist`
#' will be weighted by the number of observation data points, in doing so
#' approximately placing an independent prior at each time step and usually
#' preventing the posteriors from shifting. If FALSE, no weight will be
#' applied, i.e. any parameters in `dist` will be treated as a single
#' parameters.
#' preventing the posteriors from shifting. If FALSE, no weight
#' will be applied, i.e. any parameters in `dist` will be treated as a single
#' parameters.
#' @inheritParams apply_default_tolerance
#' @return A `<generation_time_opts>` object summarising the input delay
#' distributions.
Expand Down Expand Up @@ -401,7 +401,7 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the structure of the approximate Gaussian
#' process. Custom settings can be supplied which override the defaults.
#' process. Custom settings can be supplied which override the defaults.
#'
#' @param ls_mean Numeric, defaults to 21 days. The mean of the lognormal
#' length scale.
Expand All @@ -411,31 +411,34 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#' process length scale will be used with recommended parameters
#' \code{inv_gamma(1.499007, 0.057277 * ls_max)}.
#'
#' @param ls_min Numeric, defaults to 0. The minimum value of the length scale.
#'
#' @param ls_max Numeric, defaults to 60. The maximum value of the length
#' scale. Updated in [create_gp_data()] to be the length of the input data if
#' this is smaller.
#'
#' @param ls_min Numeric, defaults to 0. The minimum value of the length scale.
#' @param alpha_mean Numeric, defaults to 0. The mean of the magnitude parameter
#' of the Gaussian process kernel. Should be approximately the expected variance
#' of the logged Rt.
#'
#' @param alpha_sd Numeric, defaults to 0.05. The standard deviation of the
#' magnitude parameter of the Gaussian process kernel. Should be approximately
#' @param alpha_sd Numeric, defaults to 0.01. The standard deviation of the
#' magnitude parameter of the Gaussian process kernel. Should be approximately
#' the expected standard deviation of the logged Rt.
#'
#' @param kernel Character string, the type of kernel required. Currently
#' supporting the squared exponential kernel ("se", or "matern" with
#' 'matern_order = Inf'), 3 over 2 oder 5 over 2 Matern kernel ("matern", with
#' `matern_order = 3/2` (default) or `matern_order = 5/2`, respectively), or
#' Orstein-Uhlenbeck ("ou", or "matern" with 'matern_order = 1/2'). Defaulting
#' to the Matérn 3 over 2 kernel for a balance of smoothness and
#' discontinuities.
#' supporting the Matern kernel ("matern"), squared exponential kernel ("se"),
#' periodic kernel, Ornstein-Uhlenbeck #' kernel ("ou"), and the periodic
#' kernel ("periodic").
#'
#' @param matern_order Numeric, defaults to 3/2. Order of Matérn Kernel to use.
#' Currently the orders 1/2, 3/2, 5/2 and Inf are supported.
#' Common choices are 1/2, 3/2, and 5/2. If `kernel` is set
#' to "ou", `matern_order` will be automatically set to 1/2. Only used if
#' the kernel is set to "matern".
#'
#' @param matern_type Deprated; Numeric, defaults to 3/2. Order of Matérn Kernel
#' to use. Currently the orders 1/2, 3/2, 5/2 and Inf are supported.
#' @param matern_type Deprecated; Numeric, defaults to 3/2. Order of Matérn
#' Kernel to use. Currently, the orders 1/2, 3/2, 5/2 and Inf are supported.
#'
#' @param basis_prop Numeric, proportion of time points to use as basis
#' @param basis_prop Numeric, the proportion of time points to use as basis
#' functions. Defaults to 0.2. Decreasing this value results in a decrease in
#' accuracy but a faster compute time (with increasing it having the first
#' effect). In general smaller posterior length scales require a higher
Expand All @@ -446,6 +449,9 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#' approximate Gaussian process. See (Riutort-Mayol et al. 2020
#' <https://arxiv.org/abs/2004.11408>) for advice on updating this default.
#'
#' @param w0 Numeric, defaults to 1.0. Fundamental frequency for periodic
#' kernel. They are only used if `kernel` is set to "periodic".
#'
#' @importFrom rlang arg_match
#' @return A `<gp_opts>` object of settings defining the Gaussian process
#' @export
Expand All @@ -455,21 +461,30 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#'
#' # add a custom length scale
#' gp_opts(ls_mean = 4)
#'
#' # use linear kernel
#' gp_opts(kernel = "periodic")
gp_opts <- function(basis_prop = 0.2,
boundary_scale = 1.5,
ls_mean = 21,
ls_sd = 7,
ls_min = 0,
ls_max = 60,
alpha_sd = 0.05,
kernel = c("matern", "se", "ou"),
alpha_mean = 0,
alpha_sd = 0.01,
kernel = c("matern", "se", "ou", "periodic"),
matern_order = 3 / 2,
seabbs marked this conversation as resolved.
Show resolved Hide resolved
matern_type) {
lifecycle::deprecate_warn(
"1.6.0", "gp_opts(matern_type)", "gp_opts(matern_order)"
)
matern_type,
w0 = 1.0) {

if (!missing(matern_type)) {
if (!missing(matern_order) && matern_type == matern_order) {
lifecycle::deprecate_warn(
"1.6.0", "gp_opts(matern_type)", "gp_opts(matern_order)"
)
}

if (!missing(matern_type)) {
if (!missing(matern_order) && matern_type != matern_order) {
stop(
"Incompatible `matern_order` and `matern_type`. ",
"Use `matern_order` only."
Expand All @@ -480,20 +495,15 @@ gp_opts <- function(basis_prop = 0.2,

kernel <- arg_match(kernel)
if (kernel == "se") {
if (!missing(matern_order) && is.finite(matern_order)) {
stop("Squared exponential kernel must have matern order unset or `Inf`.")
}
matern_order <- Inf
} else if (kernel == "ou") {
if (!missing(matern_order) && matern_order != 1 / 2) {
stop("Ornstein-Uhlenbeck kernel must have matern order unset or `1 / 2`.") ## nolint: nonportable_path_linter
}
matern_order <- 1 / 2
} else if (!(is.infinite(matern_order) ||
matern_order %in% c(1 / 2, 3 / 2, 5 / 2))) {
stop(
"only the Matern kernels of order `1 / 2`, `3 / 2`, `5 / 2` or `Inf` ", ## nolint: nonportable_path_linter
"are currently supported"
} else if (
!(is.infinite(matern_order) || matern_order %in% c(1 / 2, 3 / 2, 5 / 2))
) {
warning(
"Uncommon Matern kernel order. Common orders are `1 / 2`, `3 / 2`,", # nolint
" and `5 / 2`" # nolint
)
}

Expand All @@ -504,9 +514,11 @@ gp_opts <- function(basis_prop = 0.2,
ls_sd = ls_sd,
ls_min = ls_min,
ls_max = ls_max,
alpha_mean = alpha_mean,
alpha_sd = alpha_sd,
kernel = kernel,
matern_order = matern_order
matern_order = matern_order,
w0 = w0
)

attr(gp, "class") <- c("gp_opts", class(gp))
Expand All @@ -523,8 +535,10 @@ gp_opts <- function(basis_prop = 0.2,
#' @param phi Overdispersion parameter of the reporting process, used only if
#' `familiy` is "negbin". Can be supplied either as a single numeric value
#' (fixed overdispersion) or a list with numeric elements mean (`mean`) and
#' standard deviation (`sd`) defining a normally distributed overdispersion.
#' Defaults to a list with elements `mean = 0` and `sd = 1`.
#' standard deviation (`sd`) defining a normally distributed prior.
#' Internally parametersed such that the overedispersion is one over the
#' square of this prior overdispersion. Defaults to a list with elements
#' `mean = 0` and `sd = 0.25`.
#' @param weight Numeric, defaults to 1. Weight to give the observed data in the
#' log density.
#' @param week_effect Logical defaulting to `TRUE`. Should a day of the week
Expand Down Expand Up @@ -563,7 +577,7 @@ gp_opts <- function(basis_prop = 0.2,
#' # Scale reported data
#' obs_opts(scale = list(mean = 0.2, sd = 0.02))
obs_opts <- function(family = c("negbin", "poisson"),
phi = list(mean = 0, sd = 1),
phi = list(mean = 0, sd = 0.25),
weight = 1,
week_effect = TRUE,
week_length = 7,
Expand Down Expand Up @@ -634,8 +648,8 @@ obs_opts <- function(family = c("negbin", "poisson"),
#' @param chains Numeric, defaults to 4. Number of MCMC chains to use.
#'
#' @param control List, defaults to empty. control parameters to pass to
#' underlying `rstan` function. By default `adapt_delta = 0.95` and
#' `max_treedepth = 15` though these settings can be overwritten.
#' underlying `rstan` function. By default `adapt_delta = 0.9` and
#' `max_treedepth = 12` though these settings can be overwritten.
#'
#' @param save_warmup Logical, defaults to FALSE. Should warmup progress be
#' saved.
Expand Down Expand Up @@ -684,7 +698,7 @@ stan_sampling_opts <- function(cores = getOption("mc.cores", 1L),
future = future,
max_execution_time = max_execution_time
)
control_def <- list(adapt_delta = 0.95, max_treedepth = 15)
control_def <- list(adapt_delta = 0.9, max_treedepth = 12)
control_def <- modifyList(control_def, control)
if (any(c("iter", "iter_sampling") %in% names(dot_args))) {
warning(
Expand Down
3 changes: 1 addition & 2 deletions R/regional_epinow.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@
#' delays = delay_opts(example_incubation_period + example_reporting_delay),
#' rt = rt_opts(prior = list(mean = 2, sd = 0.2)),
#' stan = stan_opts(
#' samples = 100, warmup = 200,
#' control = list(adapt_delta = 0.95)
#' samples = 100, warmup = 200
#' ),
#' verbose = interactive()
#' )
Expand Down
1 change: 0 additions & 1 deletion R/simulate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ simulate_infections <- function(estimates, R, initial_infections,
#' generation_time = generation_time_opts(example_generation_time),
#' delays = delay_opts(example_incubation_period + example_reporting_delay),
#' rt = rt_opts(prior = list(mean = 2, sd = 0.1), rw = 7),
#' stan = stan_opts(control = list(adapt_delta = 0.9)),
seabbs marked this conversation as resolved.
Show resolved Hide resolved
#' obs = obs_opts(scale = list(mean = 0.1, sd = 0.01)),
#' gp = NULL, horizon = 0
#' )
Expand Down
5 changes: 4 additions & 1 deletion inst/stan/data/gaussian_process.stan
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
real ls_sdlog; // sdlog for gp lengthscale prior
real<lower=0> ls_min; // Lower bound for the lengthscale
real<lower=0> ls_max; // Upper bound for the lengthscale
real alpha_mean; // mean of the alpha gp kernal parameter
real alpha_sd; // standard deviation of the alpha gp kernal parameter
int gp_type; // type of gp, 0 = squared exponential, 1 = 3/2 matern
int gp_type; // type of gp, 0 = squared exponential, 1 = periodic, 2 = Matern
real nu; // smoothness parameter for Matern kernel (used if gp_type = 2)
real w0; // fundamental frequency for periodic kernel (used if gp_type = 1)
int stationary; // is underlying gaussian process first or second order
int fixed; // should a gaussian process be used
Loading
Loading