diff --git a/NEWS.md b/NEWS.md index bf51d48dd..92c69cf07 100644 --- a/NEWS.md +++ b/NEWS.md @@ -11,6 +11,12 @@ - `epinow()` now returns the "timing" output in a "time difference"" format that is easier to understand and work with. By @jamesmbaazam in #688 and reviewed by @sbfnk. - The interface for defining delay distributions has been generalised to also cater for continuous distributions - When defining probability distributions these can now be truncated using the `tolerance` argument +- Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in #741 and reviewed by @seabbs. +- Gaussian processes have been vectorised, leading to some speed gains 🚀 , and the `gp_opts()` function has gained a "periodic" kernel option (with fundamental frequency `w0`) alongside the existing "matern", "se", and "ou" kernels. By @seabbs in #742 and reviewed by @jamesmbaazam. +- Prior predictive checks have been used to update the following priors: the prior on the magnitude of the Gaussian process (from HalfNormal(0, 0.05) to HalfNormal(0, 0.01)), and the prior on the overdispersion (from 1 / HalfNormal(0, 1)^2 to 1 / HalfNormal(0, 0.25)^2). In the user-facing API, this is a change in default values of the `sd` of `phi` in `obs_opts()` from 1 to 0.25. By @seabbs in #742 and reviewed by @jamesmbaazam. +- The default stan control options have been updated from `list(adapt_delta = 0.95, max_treedepth = 15)` to `list(adapt_delta = 0.9, max_treedepth = 12)` due to improved performance and to reduce the runtime of the default parameterisations. By @seabbs in #742 and reviewed by @jamesmbaazam. +- Initialisation has been simplified by sampling directly from the priors, where possible, rather than from a constrained space. By @seabbs in #742 and reviewed by @jamesmbaazam. +- Unnecessary normalisation of delay priors has been removed. By @seabbs in #742 and reviewed by @jamesmbaazam. - Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in # and reviewed by @. 
- Switch to broadcasting from random walks and added unit tests. By @seabbs in #747 and reviewed by @jamesmbaazam. - Optimised convolution code to take into account the relative length of the vectors being convolved. See #745 by @seabbs and reviewed by @jamesmbaazam. diff --git a/R/create.R b/R/create.R index d3e8bedc6..55891230e 100644 --- a/R/create.R +++ b/R/create.R @@ -346,6 +346,7 @@ create_backcalc_data <- function(backcalc = backcalc_opts()) { ) return(data) } + #' Create Gaussian Process Data #' #' @description `r lifecycle::badge("stable")` @@ -387,33 +388,50 @@ create_gp_data <- function(gp = gp_opts(), data) { } else { fixed <- FALSE } - # reset ls_max if larger than observed time - time <- data$t - data$seeding_time - data$horizon - if (gp$ls_max > time) { - gp$ls_max <- time + + time <- data$t - data$seeding_time + if (data$future_fixed > 0) { + time <- time + data$fixed_from - data$horizon + } + if (data$stationary == 1) { + time <- time - 1 } + obs_time <- data$t - data$seeding_time + if (gp$ls_max > obs_time) { + gp$ls_max <- obs_time + } + + times <- seq_len(time) + + rescaled_times <- (times - mean(times)) / sd(times) + gp$ls_mean <- gp$ls_mean / sd(times) + gp$ls_sd <- gp$ls_sd / sd(times) + gp$ls_min <- gp$ls_min / sd(times) + gp$ls_max <- gp$ls_max / sd(times) + # basis functions - M <- data$t - data$seeding_time - M <- ifelse(data$future_fixed == 1, M - (data$horizon - data$fixed_from), M) - M <- ceiling(M * gp$basis_prop) + M <- ceiling(time * gp$basis_prop) # map settings to underlying gp stan requirements gp_data <- list( fixed = as.numeric(fixed), M = M, - L = gp$boundary_scale, + L = gp$boundary_scale * max(rescaled_times), ls_meanlog = convert_to_logmean(gp$ls_mean, gp$ls_sd), ls_sdlog = convert_to_logsd(gp$ls_mean, gp$ls_sd), ls_min = gp$ls_min, - ls_max = data$t - data$seeding_time - data$horizon, + ls_max = gp$ls_max, + alpha_mean = gp$alpha_mean, alpha_sd = gp$alpha_sd, gp_type = data.table::fcase( - is.infinite(gp$matern_order), 0, 
- gp$matern_order == 1 / 2, 1, - gp$matern_order == 3 / 2, 2, - default = 3 - ) + gp$kernel == "se", 0, + gp$kernel == "periodic", 1, + gp$kernel == "matern" || gp$kernel == "ou", 2, + default = 2 + ), + nu = gp$matern_order, + w0 = gp$w0 ) gp_data <- c(data, gp_data) @@ -606,42 +624,44 @@ create_initial_conditions <- function(data) { out <- create_delay_inits(data) if (data$fixed == 0) { - out$eta <- array(rnorm(data$M, mean = 0, sd = 0.1)) - out$rho <- array(rlnorm(1, + out$eta <- array(rnorm( + ifelse(data$gp_type == 1, data$M * 2, data$M), mean = 0, sd = 0.1)) + out$rescaled_rho <- array(rlnorm(1, meanlog = data$ls_meanlog, - sdlog = ifelse(data$ls_sdlog > 0, data$ls_sdlog * 0.1, 0.01) + sdlog = ifelse(data$ls_sdlog > 0, data$ls_sdlog, 0.01) + )) + out$rescaled_rho <- array(data.table::fcase( + out$rescaled_rho > data$ls_max, data$ls_max - 0.001, + out$rescaled_rho < data$ls_min, data$ls_min + 0.001, + default = out$rescaled_rho )) - - out$rho <- array(data.table::fcase( - out$rho > data$ls_max, data$ls_max - 0.001, - out$rho < data$ls_min, data$ls_min + 0.001, - default = out$rho - )) out$alpha <- array( - truncnorm::rtruncnorm(1, a = 0, mean = 0, sd = data$alpha_sd) + truncnorm::rtruncnorm( + 1, a = 0, mean = data$alpha_mean, sd = data$alpha_sd + ) ) } else { out$eta <- array(numeric(0)) - out$rho <- array(numeric(0)) + out$rescaled_rho <- array(numeric(0)) out$alpha <- array(numeric(0)) } if (data$model_type == 1) { out$rep_phi <- array( truncnorm::rtruncnorm( 1, - a = 0, mean = data$phi_mean, sd = data$phi_sd / 10 + a = 0, mean = data$phi_mean, sd = data$phi_sd ) ) } if (data$estimate_r == 1) { - out$initial_infections <- array(rnorm(1, data$prior_infections, 0.02)) + out$initial_infections <- array(rnorm(1, data$prior_infections, 0.2)) if (data$seeding_time > 1) { - out$initial_growth <- array(rnorm(1, data$prior_growth, 0.01)) + out$initial_growth <- array(rnorm(1, data$prior_growth, 0.02)) } out$log_R <- array(rnorm( n = 1, mean = 
convert_to_logmean(data$r_mean, data$r_sd), - sd = convert_to_logsd(data$r_mean, data$r_sd) * 0.1 + sd = convert_to_logsd(data$r_mean, data$r_sd) )) } @@ -656,7 +676,7 @@ create_initial_conditions <- function(data) { out$frac_obs <- array(truncnorm::rtruncnorm(1, a = 0, b = 1, mean = data$obs_scale_mean, - sd = data$obs_scale_sd * 0.1 + sd = data$obs_scale_sd )) } else { out$frac_obs <- array(numeric(0)) diff --git a/R/estimate_infections.R b/R/estimate_infections.R index 35e36946b..fde83e449 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -101,8 +101,7 @@ #' def <- estimate_infections(reported_cases, #' generation_time = gt_opts(generation_time), #' delays = delay_opts(incubation_period + reporting_delay), -#' rt = rt_opts(prior = list(mean = 2, sd = 0.1)), -#' stan = stan_opts(control = list(adapt_delta = 0.95)) +#' rt = rt_opts(prior = list(mean = 2, sd = 0.1)) #' ) #' # real time estimates #' summary(def) diff --git a/R/opts.R b/R/opts.R index cc1dc0853..fe352100e 100644 --- a/R/opts.R +++ b/R/opts.R @@ -15,9 +15,9 @@ #' @param weight_prior Logical; if TRUE (default), any priors given in `dist` #' will be weighted by the number of observation data points, in doing so #' approximately placing an independent prior at each time step and usually -#' preventing the posteriors from shifting. If FALSE, no weight will be -#' applied, i.e. any parameters in `dist` will be treated as a single -#' parameters. +#' preventing the posteriors from shifting. If FALSE, no weight +#' will be applied, i.e. any parameters in `dist` will be treated as a single +#' parameters. #' @inheritParams apply_default_tolerance #' @return A `` object summarising the input delay #' distributions. @@ -401,7 +401,7 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"), #' #' @description `r lifecycle::badge("stable")` #' Defines a list specifying the structure of the approximate Gaussian -#' process. 
Custom settings can be supplied which override the defaults. +#' process. Custom settings can be supplied which override the defaults. #' #' @param ls_mean Numeric, defaults to 21 days. The mean of the lognormal #' length scale. @@ -411,31 +411,34 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"), #' process length scale will be used with recommended parameters #' \code{inv_gamma(1.499007, 0.057277 * ls_max)}. #' +#' @param ls_min Numeric, defaults to 0. The minimum value of the length scale. +#' #' @param ls_max Numeric, defaults to 60. The maximum value of the length #' scale. Updated in [create_gp_data()] to be the length of the input data if #' this is smaller. #' -#' @param ls_min Numeric, defaults to 0. The minimum value of the length scale. +#' @param alpha_mean Numeric, defaults to 0. The mean of the magnitude parameter +#' of the Gaussian process kernel. Should be approximately the expected variance +#' of the logged Rt. #' -#' @param alpha_sd Numeric, defaults to 0.05. The standard deviation of the -#' magnitude parameter of the Gaussian process kernel. Should be approximately +#' @param alpha_sd Numeric, defaults to 0.01. The standard deviation of the +#' magnitude parameter of the Gaussian process kernel. Should be approximately #' the expected standard deviation of the logged Rt. #' #' @param kernel Character string, the type of kernel required. Currently -#' supporting the squared exponential kernel ("se", or "matern" with -#' 'matern_order = Inf'), 3 over 2 oder 5 over 2 Matern kernel ("matern", with -#' `matern_order = 3/2` (default) or `matern_order = 5/2`, respectively), or -#' Orstein-Uhlenbeck ("ou", or "matern" with 'matern_order = 1/2'). Defaulting -#' to the Matérn 3 over 2 kernel for a balance of smoothness and -#' discontinuities. +#' supporting the Matern kernel ("matern"), squared exponential kernel ("se"), +#' Ornstein-Uhlenbeck kernel ("ou"), and the periodic +#' kernel ("periodic"). 
#' #' @param matern_order Numeric, defaults to 3/2. Order of Matérn Kernel to use. -#' Currently the orders 1/2, 3/2, 5/2 and Inf are supported. +#' Common choices are 1/2, 3/2, and 5/2. If `kernel` is set +#' to "ou", `matern_order` will be automatically set to 1/2. Only used if +#' the kernel is set to "matern". #' -#' @param matern_type Deprated; Numeric, defaults to 3/2. Order of Matérn Kernel -#' to use. Currently the orders 1/2, 3/2, 5/2 and Inf are supported. +#' @param matern_type Deprecated; Numeric, defaults to 3/2. Order of Matérn +#' Kernel to use. Currently, the orders 1/2, 3/2, 5/2 and Inf are supported. #' -#' @param basis_prop Numeric, proportion of time points to use as basis +#' @param basis_prop Numeric, the proportion of time points to use as basis #' functions. Defaults to 0.2. Decreasing this value results in a decrease in #' accuracy but a faster compute time (with increasing it having the first #' effect). In general smaller posterior length scales require a higher @@ -446,6 +449,9 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"), #' approximate Gaussian process. See (Riutort-Mayol et al. 2020 #' ) for advice on updating this default. #' +#' @param w0 Numeric, defaults to 1.0. Fundamental frequency for the periodic +#' kernel. Only used if `kernel` is set to "periodic". 
+#' #' @importFrom rlang arg_match #' @return A `` object of settings defining the Gaussian process #' @export @@ -455,21 +461,30 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"), #' #' # add a custom length scale #' gp_opts(ls_mean = 4) +#' +#' # use a periodic kernel +#' gp_opts(kernel = "periodic") gp_opts <- function(basis_prop = 0.2, boundary_scale = 1.5, ls_mean = 21, ls_sd = 7, ls_min = 0, ls_max = 60, - alpha_sd = 0.05, - kernel = c("matern", "se", "ou"), + alpha_mean = 0, + alpha_sd = 0.01, + kernel = c("matern", "se", "ou", "periodic"), matern_order = 3 / 2, - matern_type) { - lifecycle::deprecate_warn( - "1.6.0", "gp_opts(matern_type)", "gp_opts(matern_order)" - ) + matern_type, + w0 = 1.0) { + if (!missing(matern_type)) { - if (!missing(matern_order) && matern_type == matern_order) { + lifecycle::deprecate_warn( + "1.6.0", "gp_opts(matern_type)", "gp_opts(matern_order)" + ) + } + + if (!missing(matern_type)) { + if (!missing(matern_order) && matern_type != matern_order) { stop( "Incompatible `matern_order` and `matern_type`. ", "Use `matern_order` only." @@ -480,20 +495,15 @@ gp_opts <- function(basis_prop = 0.2, kernel <- arg_match(kernel) if (kernel == "se") { - if (!missing(matern_order) && is.finite(matern_order)) { - stop("Squared exponential kernel must have matern order unset or `Inf`.") - } matern_order <- Inf } else if (kernel == "ou") { - if (!missing(matern_order) && matern_order != 1 / 2) { - stop("Ornstein-Uhlenbeck kernel must have matern order unset or `1 / 2`.") ## nolint: nonportable_path_linter - } matern_order <- 1 / 2 - } else if (!(is.infinite(matern_order) || - matern_order %in% c(1 / 2, 3 / 2, 5 / 2))) { - stop( - "only the Matern kernels of order `1 / 2`, `3 / 2`, `5 / 2` or `Inf` ", ## nolint: nonportable_path_linter - "are currently supported" + } else if ( + !(is.infinite(matern_order) || matern_order %in% c(1 / 2, 3 / 2, 5 / 2)) + ) { + warning( + "Uncommon Matern kernel order. 
Common orders are `1 / 2`, `3 / 2`,", # nolint + " and `5 / 2`" # nolint ) } @@ -504,9 +514,11 @@ gp_opts <- function(basis_prop = 0.2, ls_sd = ls_sd, ls_min = ls_min, ls_max = ls_max, + alpha_mean = alpha_mean, alpha_sd = alpha_sd, kernel = kernel, - matern_order = matern_order + matern_order = matern_order, + w0 = w0 ) attr(gp, "class") <- c("gp_opts", class(gp)) @@ -523,8 +535,10 @@ gp_opts <- function(basis_prop = 0.2, #' @param phi Overdispersion parameter of the reporting process, used only if #' `familiy` is "negbin". Can be supplied either as a single numeric value #' (fixed overdispersion) or a list with numeric elements mean (`mean`) and -#' standard deviation (`sd`) defining a normally distributed overdispersion. -#' Defaults to a list with elements `mean = 0` and `sd = 1`. +#' standard deviation (`sd`) defining a normally distributed prior. +#' Internally parameterised such that the overdispersion is one over the +#' square of the value drawn from this prior. Defaults to a list with elements +#' `mean = 0` and `sd = 0.25`. #' @param weight Numeric, defaults to 1. Weight to give the observed data in the #' log density. #' @param week_effect Logical defaulting to `TRUE`. Should a day of the week @@ -563,7 +577,7 @@ gp_opts <- function(basis_prop = 0.2, #' # Scale reported data #' obs_opts(scale = list(mean = 0.2, sd = 0.02)) obs_opts <- function(family = c("negbin", "poisson"), - phi = list(mean = 0, sd = 1), + phi = list(mean = 0, sd = 0.25), weight = 1, week_effect = TRUE, week_length = 7, @@ -634,8 +648,8 @@ obs_opts <- function(family = c("negbin", "poisson"), #' @param chains Numeric, defaults to 4. Number of MCMC chains to use. #' #' @param control List, defaults to empty. control parameters to pass to -#' underlying `rstan` function. By default `adapt_delta = 0.95` and -#' `max_treedepth = 15` though these settings can be overwritten. +#' underlying `rstan` function. 
By default `adapt_delta = 0.9` and +#' `max_treedepth = 12` though these settings can be overwritten. #' #' @param save_warmup Logical, defaults to FALSE. Should warmup progress be #' saved. @@ -684,7 +698,7 @@ stan_sampling_opts <- function(cores = getOption("mc.cores", 1L), future = future, max_execution_time = max_execution_time ) - control_def <- list(adapt_delta = 0.95, max_treedepth = 15) + control_def <- list(adapt_delta = 0.9, max_treedepth = 12) control_def <- modifyList(control_def, control) if (any(c("iter", "iter_sampling") %in% names(dot_args))) { warning( diff --git a/R/regional_epinow.R b/R/regional_epinow.R index 80408d0ca..da06b12a9 100644 --- a/R/regional_epinow.R +++ b/R/regional_epinow.R @@ -83,8 +83,7 @@ #' delays = delay_opts(example_incubation_period + example_reporting_delay), #' rt = rt_opts(prior = list(mean = 2, sd = 0.2)), #' stan = stan_opts( -#' samples = 100, warmup = 200, -#' control = list(adapt_delta = 0.95) +#' samples = 100, warmup = 200 #' ), #' verbose = interactive() #' ) diff --git a/R/simulate_infections.R b/R/simulate_infections.R index 59ddd950c..4b0974ed8 100644 --- a/R/simulate_infections.R +++ b/R/simulate_infections.R @@ -273,7 +273,6 @@ simulate_infections <- function(estimates, R, initial_infections, #' generation_time = generation_time_opts(example_generation_time), #' delays = delay_opts(example_incubation_period + example_reporting_delay), #' rt = rt_opts(prior = list(mean = 2, sd = 0.1), rw = 7), -#' stan = stan_opts(control = list(adapt_delta = 0.9)), #' obs = obs_opts(scale = list(mean = 0.1, sd = 0.01)), #' gp = NULL, horizon = 0 #' ) diff --git a/inst/stan/data/gaussian_process.stan b/inst/stan/data/gaussian_process.stan index 82119a127..8154ffdfe 100644 --- a/inst/stan/data/gaussian_process.stan +++ b/inst/stan/data/gaussian_process.stan @@ -4,7 +4,10 @@ real ls_sdlog; // sdlog for gp lengthscale prior real ls_min; // Lower bound for the lengthscale real ls_max; // Upper bound for the lengthscale + real 
alpha_mean; // mean of the alpha gp kernal parameter real alpha_sd; // standard deviation of the alpha gp kernal parameter - int gp_type; // type of gp, 0 = squared exponential, 1 = 3/2 matern + int gp_type; // type of gp, 0 = squared exponential, 1 = periodic, 2 = Matern + real nu; // smoothness parameter for Matern kernel (used if gp_type = 2) + real w0; // fundamental frequency for periodic kernel (used if gp_type = 1) int stationary; // is underlying gaussian process first or second order int fixed; // should a gaussian process be used diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index 985cfa9c2..303b6b0e4 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -9,7 +9,6 @@ functions { #include functions/generated_quantities.stan } - data { #include data/observations.stan #include data/delays.stan @@ -19,7 +18,7 @@ data { #include data/observation_model.stan } -transformed data{ +transformed data { // observations int ot = t - seeding_time - horizon; // observed time int ot_h = ot + horizon; // observed time + forecast horizon @@ -27,7 +26,7 @@ transformed data{ int noise_terms = setup_noise( ot_h, t, horizon, estimate_r, stationary, future_fixed, fixed_from ); - matrix[noise_terms, M] PHI = setup_gp(M, L, noise_terms); // basis function + matrix[noise_terms, gp_type == 1 ? 2*M : M] PHI = setup_gp(M, L, noise_terms, gp_type == 1, w0); // basis function // Rt real r_logmean = log(r_mean^2 / sqrt(r_sd^2 + r_mean^2)); real r_logsd = sqrt(log(1 + (r_sd^2 / r_mean^2))); @@ -41,38 +40,41 @@ transformed data{ } } -parameters{ +parameters { // gaussian process - array[fixed ? 0 : 1] real rho; // length scale of noise GP - array[fixed ? 0 : 1] real alpha; // scale of of noise GP - vector[fixed ? 0 : M] eta; // unconstrained noise + array[fixed ? 0 : 1] real rescaled_rho; // length scale of noise GP + array[fixed ? 0 : 1] real alpha; // scale of noise GP + vector[fixed ? 0 : gp_type == 1 ? 
2*M : M] eta; // unconstrained noise // Rt vector[estimate_r] log_R; // baseline reproduction number estimate (log) - array[estimate_r] real initial_infections ; // seed infections + array[estimate_r] real initial_infections; // seed infections array[estimate_r && seeding_time > 1 ? 1 : 0] real initial_growth; // seed growth rate array[bp_n > 0 ? 1 : 0] real bp_sd; // standard deviation of breakpoint effect vector[bp_n] bp_effects; // Rt breakpoint effects // observation model - vector[delay_params_length] delay_params; // delay parameters - simplex[week_effect] day_of_week_simplex;// day of week reporting effect - array[obs_scale_sd > 0 ? 1 : 0] real frac_obs; // fraction of cases that are ultimately observed - array[model_type] real rep_phi; // overdispersion of the reporting process + simplex[week_effect] day_of_week_simplex; // day of week reporting effect + array[obs_scale_sd > 0 ? 1 : 0] real frac_obs; // fraction of cases that are ultimately observed + array[model_type] real rep_phi; // overdispersion of the reporting process } transformed parameters { - vector[fixed ? 0 : noise_terms] noise; // noise generated by the gaussian process + vector[fixed ? 0 : noise_terms] noise; // noise generated by the gaussian process vector[estimate_r > 0 ? 
ot_h : 0] R; // reproduction number vector[t] infections; // latent infections vector[ot_h] reports; // estimated reported cases vector[ot] obs_reports; // observed estimated reported cases vector[estimate_r * (delay_type_max[gt_id] + 1)] gt_rev_pmf; + // GP in noise - spectral densities profile("update gp") { if (!fixed) { - noise = update_gp(PHI, M, L, alpha[1], rho[1], eta, gp_type); + noise = update_gp( + PHI, M, L, alpha[1], rescaled_rho, eta, gp_type, nu + ); } } + // Estimate latent infections if (estimate_r) { profile("gt") { @@ -102,6 +104,7 @@ transformed parameters { ); } } + // convolve from latent infections to mean of observations if (delay_id) { vector[delay_type_max[delay_id] + 1] delay_rev_pmf; @@ -119,12 +122,14 @@ transformed parameters { } else { reports = infections[(seeding_time + 1):t]; } + // weekly reporting effect if (week_effect > 1) { profile("day of the week") { reports = day_of_week_effect(reports, day_of_week, day_of_week_simplex); } } + // scaling of reported cases by fraction observed if (obs_scale) { profile("scale") { @@ -133,6 +138,7 @@ transformed parameters { ); } } + // truncate near time cases to observed reports if (trunc_id) { vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf; @@ -156,18 +162,21 @@ model { // priors for noise GP if (!fixed) { profile("gp lp") { - gaussian_process_lp( - rho[1], alpha[1], eta, ls_meanlog, ls_sdlog, ls_min, ls_max, alpha_sd - ); + gaussian_process_lp(alpha[1], eta, alpha_mean, alpha_sd); + if (gp_type != 3) { + lengthscale_lp(rescaled_rho[1], ls_meanlog, ls_sdlog, ls_min, ls_max); + } } } - // penalised priors for delay distributions + + // penalized priors for delay distributions profile("delays lp") { delays_lp( delay_params, delay_params_mean, delay_params_sd, delay_params_groups, delay_dist, delay_weight ); } + if (estimate_r) { // priors on Rt profile("rt lp") { @@ -177,12 +186,14 @@ model { ); } } + // prior observation scaling if (obs_scale_sd > 0) { profile("scale lp") { frac_obs[1] ~ 
normal(obs_scale_mean, obs_scale_sd) T[0, 1]; } } + // observed reports from mean of reports (update likelihood) if (likelihood) { profile("report lp") { @@ -196,11 +207,18 @@ model { generated quantities { array[ot_h] int imputed_reports; - vector[estimate_r > 0 ? 0: ot_h] gen_R; + vector[estimate_r > 0 ? 0 : ot_h] gen_R; vector[ot_h - 1] r; vector[return_likelihood ? ot : 0] log_lik; + vector[fixed ? 0 : 1] rho; + profile("generated quantities") { - if (estimate_r == 0){ + if (!fixed && gp_type != 3) { + vector[noise_terms] x = linspaced_vector(noise_terms, 1, noise_terms); + rho[1] = rescaled_rho[1] * sd(x); + } + + if (estimate_r == 0) { // sample generation time vector[delay_params_length] delay_params_sample = to_vector(normal_lb_rng( delay_params_mean, delay_params_sd, delay_params_lower @@ -216,10 +234,13 @@ generated quantities { infections, seeding_time, sampled_gt_rev_pmf, rt_half_window ); } + // estimate growth from infections r = calculate_growth(infections, seeding_time + 1); + // simulate reported cases imputed_reports = report_rng(reports, rep_phi, model_type); + // log likelihood of model if (return_likelihood) { log_lik = report_log_lik( diff --git a/inst/stan/functions/delays.stan b/inst/stan/functions/delays.stan index 784ad36c5..c4c89a09a 100644 --- a/inst/stan/functions/delays.stan +++ b/inst/stan/functions/delays.stan @@ -89,15 +89,10 @@ void delays_lp(vector delay_params, int end = delay_params_groups[d + 1] - 1; for (s in start:end) { if (delay_params_sd[s] > 0) { - // uncertain mean - target += normal_lpdf( - delay_params[s] | delay_params_mean[s], delay_params_sd[s] - ) * weight[d]; - // if a distribution with postive support only truncate the prior - if (delay_dist[d] == 1) { - target += -normal_lccdf( - 0 | delay_params_mean[s], delay_params_sd[s] - ) * weight[d]; + if (weight[d] > 1) { + target += weight[d] * normal_lpdf(delay_params[s] | delay_params_mean[s], delay_params_sd[s]); + }else { + delay_params[s] ~ 
normal(delay_params_mean[s], delay_params_sd[s]); } } } diff --git a/inst/stan/functions/gaussian_process.stan b/inst/stan/functions/gaussian_process.stan index 281f0095c..9c4b50674 100644 --- a/inst/stan/functions/gaussian_process.stan +++ b/inst/stan/functions/gaussian_process.stan @@ -1,111 +1,181 @@ -// eigenvalues for approximate hilbert space gp -// see here for details: https://arxiv.org/pdf/2004.11408.pdf -real lambda(real L, int m) { - real lam; - lam = ((m * pi())/(2 * L))^2; - return lam; +/** + * These functions implement approximate Gaussian processes for Stan using + * Hilbert space methods. The functions are based on the following: + * - https://avehtari.github.io/casestudies/Motorcycle/motorcycle_gpcourse.html#4_Heteroskedastic_GP_with_Hilbert_basis_functions + * - https://arxiv.org/abs/2004.11408 + */ + +/** + * Spectral density for Exponentiated Quadratic kernel + * + * @param alpha Scaling parameter + * @param rho Length scale parameter + * @param L Length of the interval + * @param M Number of basis functions + * @return A vector of spectral densities + */ +vector diagSPD_EQ(real alpha, real rho, real L, int M) { + vector[M] indices = linspaced_vector(M, 1, M); + real factor = alpha * sqrt(sqrt(2 * pi()) * rho); + real exponent = -0.25 * (rho * pi() / 2 / L)^2; + return factor * exp(exponent * square(indices)); } -// eigenfunction for approximate hilbert space gp -// see here for details: https://arxiv.org/pdf/2004.11408.pdf -vector phi(real L, int m, vector x) { - vector[rows(x)] fi; - fi = 1/sqrt(L) * sin(m * pi()/(2 * L) * (x + L)); - return fi; +/** + * Spectral density for Matern kernel + * + * @param nu Smoothness parameter (1/2, 3/2, or 5/2) + * @param alpha Scaling parameter + * @param rho Length scale parameter + * @param L Length of the interval + * @param M Number of basis functions + * @return A vector of spectral densities + */ +vector diagSPD_Matern(real nu, real alpha, real rho, real L, int M) { + vector[M] indices = 
linspaced_vector(M, 1, M); + real factor = 2 * alpha * pow(sqrt(2 * nu) / rho, nu); + vector[M] denom = (sqrt(2 * nu) / rho)^2 + pow((pi() / 2 / L) * indices, nu + 0.5); + return factor * inv(denom); } -// spectral density of the exponential quadratic kernal -real spd_se(real alpha, real rho, real w) { - real S; - // S = (alpha^2) * sqrt(2 * pi()) * rho * exp(-0.5 * (rho^2) * (w^2)); - S = 2.506628 * alpha * rho * exp(-0.5 * (rho^2) * (w^2)); - return S; +/** + * Spectral density for periodic kernel + * + * @param alpha Scaling parameter + * @param rho Length scale parameter + * @param M Number of basis functions + * @return A vector of spectral densities + */ +vector diagSPD_Periodic(real alpha, real rho, int M) { + real a = inv_square(rho); + vector[M] indices = linspaced_vector(M, 1, M); + vector[M] q = exp(log(alpha) + 0.5 * (log(2) - a + to_vector(log_modified_bessel_first_kind(indices, a)))); + return append_row(q, q); } -// spectral density of the Ornstein-Uhlenbeck kernal -real spd_ou(real alpha, real rho, real w) { - real S; - S = 2 * alpha * rho / (1 + rho^2 * w^2); - return S; -} -// spectral density of the Matern 3/2 kernel -real spd_matern32(real alpha, real rho, real w) { - real S; - // S = 4 * alpha^2 * (sqrt(3) / rho)^3 * 1 / ((sqrt(3) / rho)^2 + w^2)^2; - S = 20.78461 * alpha / (rho^3 * (3 / rho^2 + w^2)^2); - return S; +/** + * Basis functions for Gaussian Process + * + * @param N Number of data points + * @param M Number of basis functions + * @param L Length of the interval + * @param x Vector of input data + * @return A matrix of basis functions + */ +matrix PHI(int N, int M, real L, vector x) { + matrix[N, M] phi = sin(diag_post_multiply(rep_matrix(pi() / (2 * L) * (x + L), M), linspaced_vector(M, 1, M))) / sqrt(L); + return phi; } -real spd_matern52(real alpha, real rho, real w) { - real S; - // S = 16/3 * alpha^2 * (sqrt(5) / rho)^5 * 1 / ((sqrt(5) / rho)^2 + w^2)^3 - S = 298.1424 * alpha / (rho^5 * (5 / rho^2 + w^2)^3); - return S; +/** + * 
Basis functions for periodic Gaussian Process + * + * @param N Number of data points + * @param M Number of basis functions + * @param w0 Fundamental frequency + * @param x Vector of input data + * @return A matrix of basis functions + */ +matrix PHI_periodic(int N, int M, real w0, vector x) { + matrix[N, M] mw0x = diag_post_multiply(rep_matrix(w0 * x, M), linspaced_vector(M, 1, M)); + return append_col(cos(mw0x), sin(mw0x)); } -// setup gaussian process noise dimensions +/** + * Setup Gaussian process noise dimensions + * + * @param ot_h Observation time horizon + * @param t Total time points + * @param horizon Forecast horizon + * @param estimate_r Indicator if estimating r + * @param stationary Indicator if stationary + * @param future_fixed Indicator if future is fixed + * @param fixed_from Fixed point from + * @return Number of noise terms + */ int setup_noise(int ot_h, int t, int horizon, int estimate_r, int stationary, int future_fixed, int fixed_from) { int noise_time = estimate_r > 0 ? (stationary > 0 ? ot_h : ot_h - 1) : t; - int noise_terms = - future_fixed > 0 ? (noise_time - horizon + fixed_from) : noise_time; - return(noise_terms); + int noise_terms = future_fixed > 0 ? 
(noise_time - horizon + fixed_from) : noise_time; + return noise_terms; } -// setup approximate gaussian process -matrix setup_gp(int M, real L, int dimension) { - vector[dimension] time; - matrix[dimension, M] PHI; - real half_dim = dimension / 2.0; - for (s in 1:dimension) { - time[s] = (s - half_dim) / half_dim; - } - for (m in 1:M){ - PHI[,m] = phi(L, m, time); +/** + * Setup approximate Gaussian process + * + * @param M Number of basis functions + * @param L Length of the interval + * @param dimension Dimension of the process + * @param is_periodic Indicator if the process is periodic + * @param w0 Fundamental frequency for periodic process + * @return A matrix of basis functions + */ +matrix setup_gp(int M, real L, int dimension, int is_periodic, real w0) { + vector[dimension] x = linspaced_vector(dimension, 1, dimension); + x = (x - mean(x)) / sd(x); + if (is_periodic) { + return PHI_periodic(dimension, M, w0, x); + } else { + return PHI(dimension, M, L, x); } - return(PHI); } -// update gaussian process using spectral densities +/** + * Update Gaussian process using spectral densities + * + * @param PHI Basis functions matrix + * @param M Number of basis functions + * @param L Length of the interval + * @param alpha Scaling parameter + * @param rho Length scale parameter + * @param eta Vector of noise terms + * @param type Type of kernel (0: SE, 1: Periodic, 2: Matern) + * @param nu Smoothness parameter for Matern kernel + * @return A vector of updated noise terms + */ vector update_gp(matrix PHI, int M, real L, real alpha, - real rho, vector eta, int type) { - vector[M] diagSPD; // spectral density - vector[M] SPD_eta; // spectral density * noise - int noise_terms = rows(PHI); - vector[noise_terms] noise = rep_vector(1e-6, noise_terms); - real unit_rho = rho / noise_terms; + array[] real rho, vector eta, int type, real nu) { + vector[type == 1 ? 
2 * M : M] diagSPD; // spectral density + // GP in noise - spectral densities if (type == 0) { - for(m in 1:M){ - diagSPD[m] = sqrt(spd_se(alpha, unit_rho, sqrt(lambda(L, m)))); - } + diagSPD = diagSPD_EQ(alpha, rho[1], L, M); } else if (type == 1) { - for(m in 1:M){ - diagSPD[m] = sqrt(spd_ou(alpha, unit_rho, sqrt(lambda(L, m)))); - } + diagSPD = diagSPD_Periodic(alpha, rho[1], M); } else if (type == 2) { - for(m in 1:M){ - diagSPD[m] = sqrt(spd_matern32(alpha, unit_rho, sqrt(lambda(L, m)))); - } - } else if (type == 3) { - for(m in 1:M){ - diagSPD[m] = sqrt(spd_matern52(alpha, unit_rho, sqrt(lambda(L, m)))); - } + diagSPD = diagSPD_Matern(nu, alpha, rho[1], L, M); } - SPD_eta = diagSPD .* eta; - noise = noise + PHI[,] * SPD_eta; - return(noise); + return PHI * (diagSPD .* eta); } -// priors for gaussian process -void gaussian_process_lp(real rho, real alpha, vector eta, - real ls_meanlog, real ls_sdlog, - real ls_min, real ls_max, real alpha_sd) { +/** + * Prior for Gaussian process length scale + * + * @param rho Length scale parameter + * @param ls_meanlog Mean of the log of the length scale + * @param ls_sdlog Standard deviation of the log of the length scale + * @param ls_min Minimum length scale + * @param ls_max Maximum length scale + */ +void lengthscale_lp(real rho, real ls_meanlog, real ls_sdlog, + real ls_min, real ls_max) { if (ls_sdlog > 0) { rho ~ lognormal(ls_meanlog, ls_sdlog) T[ls_min, ls_max]; } else { rho ~ inv_gamma(1.499007, 0.057277 * ls_max) T[ls_min, ls_max]; } - alpha ~ normal(0, alpha_sd) T[0,]; +} + +/** + * Priors for Gaussian process (excluding length scale) + * + * @param alpha Scaling parameter + * @param eta Vector of noise terms + * @param alpha_mean Mean of alpha + * @param alpha_sd Standard deviation of alpha + */ +void gaussian_process_lp(real alpha, vector eta, real alpha_mean, + real alpha_sd) { + alpha ~ normal(alpha_mean, alpha_sd) T[0,]; eta ~ std_normal(); } diff --git a/man/delay_opts.Rd b/man/delay_opts.Rd index 
5f56da233..5df09f408 100644 --- a/man/delay_opts.Rd +++ b/man/delay_opts.Rd @@ -27,8 +27,8 @@ constrained by having a maximum or tolerance this is ignored.} \item{weight_prior}{Logical; if TRUE (default), any priors given in \code{dist} will be weighted by the number of observation data points, in doing so approximately placing an independent prior at each time step and usually -preventing the posteriors from shifting. If FALSE, no weight will be -applied, i.e. any parameters in \code{dist} will be treated as a single +preventing the posteriors from shifting. If FALSE, no weight +will be applied, i.e. any parameters in \code{dist} will be treated as a single parameters.} } \value{ diff --git a/man/estimate_infections.Rd b/man/estimate_infections.Rd index 080db1bbc..f77981a07 100644 --- a/man/estimate_infections.Rd +++ b/man/estimate_infections.Rd @@ -144,8 +144,7 @@ reporting_delay <- LogNormal(mean = 2, sd = 1, max = 10) def <- estimate_infections(reported_cases, generation_time = gt_opts(generation_time), delays = delay_opts(incubation_period + reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.1)), - stan = stan_opts(control = list(adapt_delta = 0.95)) + rt = rt_opts(prior = list(mean = 2, sd = 0.1)) ) # real time estimates summary(def) diff --git a/man/forecast_infections.Rd b/man/forecast_infections.Rd index 822896ded..04e617011 100644 --- a/man/forecast_infections.Rd +++ b/man/forecast_infections.Rd @@ -66,7 +66,6 @@ est <- estimate_infections(reported_cases, generation_time = generation_time_opts(example_generation_time), delays = delay_opts(example_incubation_period + example_reporting_delay), rt = rt_opts(prior = list(mean = 2, sd = 0.1), rw = 7), - stan = stan_opts(control = list(adapt_delta = 0.9)), obs = obs_opts(scale = list(mean = 0.1, sd = 0.01)), gp = NULL, horizon = 0 ) diff --git a/man/generation_time_opts.Rd b/man/generation_time_opts.Rd index f0a286ea0..35350725a 100644 --- a/man/generation_time_opts.Rd +++ 
b/man/generation_time_opts.Rd @@ -48,8 +48,8 @@ constrained by having a maximum or tolerance this is ignored.} \item{weight_prior}{Logical; if TRUE (default), any priors given in \code{dist} will be weighted by the number of observation data points, in doing so approximately placing an independent prior at each time step and usually -preventing the posteriors from shifting. If FALSE, no weight will be -applied, i.e. any parameters in \code{dist} will be treated as a single +preventing the posteriors from shifting. If FALSE, no weight +will be applied, i.e. any parameters in \code{dist} will be treated as a single parameters.} } \value{ diff --git a/man/gp_opts.Rd b/man/gp_opts.Rd index ee90c50b3..c9bbf2d9e 100644 --- a/man/gp_opts.Rd +++ b/man/gp_opts.Rd @@ -11,14 +11,16 @@ gp_opts( ls_sd = 7, ls_min = 0, ls_max = 60, - alpha_sd = 0.05, - kernel = c("matern", "se", "ou"), + alpha_mean = 0, + alpha_sd = 0.01, + kernel = c("matern", "se", "ou", "periodic"), matern_order = 3/2, - matern_type + matern_type, + w0 = 1 ) } \arguments{ -\item{basis_prop}{Numeric, proportion of time points to use as basis +\item{basis_prop}{Numeric, the proportion of time points to use as basis functions. Defaults to 0.2. Decreasing this value results in a decrease in accuracy but a faster compute time (with increasing it having the first effect). In general smaller posterior length scales require a higher @@ -43,23 +45,29 @@ process length scale will be used with recommended parameters scale. Updated in \code{\link[=create_gp_data]{create_gp_data()}} to be the length of the input data if this is smaller.} -\item{alpha_sd}{Numeric, defaults to 0.05. The standard deviation of the +\item{alpha_mean}{Numeric, defaults to 0. The mean of the magnitude parameter +of the Gaussian process kernel. Should be approximately the expected variance +of the logged Rt.} + +\item{alpha_sd}{Numeric, defaults to 0.01. The standard deviation of the magnitude parameter of the Gaussian process kernel. 
Should be approximately the expected standard deviation of the logged Rt.} \item{kernel}{Character string, the type of kernel required. Currently -supporting the squared exponential kernel ("se", or "matern" with -'matern_order = Inf'), 3 over 2 oder 5 over 2 Matern kernel ("matern", with -\code{matern_order = 3/2} (default) or \code{matern_order = 5/2}, respectively), or -Orstein-Uhlenbeck ("ou", or "matern" with 'matern_order = 1/2'). Defaulting -to the Matérn 3 over 2 kernel for a balance of smoothness and -discontinuities.} +supporting the Matern kernel ("matern"), squared exponential kernel ("se"), +Ornstein-Uhlenbeck kernel ("ou"), and the periodic +kernel ("periodic").} \item{matern_order}{Numeric, defaults to 3/2. Order of Matérn Kernel to use. -Currently the orders 1/2, 3/2, 5/2 and Inf are supported.} +Common choices are 1/2, 3/2, and 5/2. If \code{kernel} is set +to "ou", \code{matern_order} will be automatically set to 1/2. Only used if +the kernel is set to "matern".} + +\item{matern_type}{Deprecated; Numeric, defaults to 3/2. Order of Matérn +Kernel to use. Currently, the orders 1/2, 3/2, 5/2 and Inf are supported.} -\item{matern_type}{Deprated; Numeric, defaults to 3/2. Order of Matérn Kernel -to use. Currently the orders 1/2, 3/2, 5/2 and Inf are supported.} +\item{w0}{Numeric, defaults to 1.0. Fundamental frequency for periodic +kernel. 
Only used if \code{kernel} is set to "periodic".} } \value{ A \verb{} object of settings defining the Gaussian process @@ -75,4 +83,7 @@ gp_opts() # add a custom length scale gp_opts(ls_mean = 4) + +# use periodic kernel +gp_opts(kernel = "periodic") } diff --git a/man/obs_opts.Rd b/man/obs_opts.Rd index c803b35a7..66a3b10dc 100644 --- a/man/obs_opts.Rd +++ b/man/obs_opts.Rd @@ -6,7 +6,7 @@ \usage{ obs_opts( family = c("negbin", "poisson"), - phi = list(mean = 0, sd = 1), + phi = list(mean = 0, sd = 0.25), weight = 1, week_effect = TRUE, week_length = 7, @@ -23,8 +23,10 @@ Negative binomial ("negbin"), the default, and Poisson.} \item{phi}{Overdispersion parameter of the reporting process, used only if \code{familiy} is "negbin". Can be supplied either as a single numeric value (fixed overdispersion) or a list with numeric elements mean (\code{mean}) and -standard deviation (\code{sd}) defining a normally distributed overdispersion. -Defaults to a list with elements \code{mean = 0} and \code{sd = 1}.} +standard deviation (\code{sd}) defining a normally distributed prior. +Internally parameterised such that the overdispersion is one over the +square of this prior overdispersion. Defaults to a list with elements +\code{mean = 0} and \code{sd = 0.25}.} \item{weight}{Numeric, defaults to 1. 
Weight to give the observed data in the log density.} diff --git a/man/regional_epinow.Rd b/man/regional_epinow.Rd index ea6134a86..64b150922 100644 --- a/man/regional_epinow.Rd +++ b/man/regional_epinow.Rd @@ -156,8 +156,7 @@ def <- regional_epinow( delays = delay_opts(example_incubation_period + example_reporting_delay), rt = rt_opts(prior = list(mean = 2, sd = 0.2)), stan = stan_opts( - samples = 100, warmup = 200, - control = list(adapt_delta = 0.95) + samples = 100, warmup = 200 ), verbose = interactive() ) diff --git a/man/rstan_sampling_opts.Rd b/man/rstan_sampling_opts.Rd index a1889d609..8847db6db 100644 --- a/man/rstan_sampling_opts.Rd +++ b/man/rstan_sampling_opts.Rd @@ -31,8 +31,8 @@ When using multiple chains iterations per chain is samples / chains.} \item{chains}{Numeric, defaults to 4. Number of MCMC chains to use.} \item{control}{List, defaults to empty. control parameters to pass to -underlying \code{rstan} function. By default \code{adapt_delta = 0.95} and -\code{max_treedepth = 15} though these settings can be overwritten.} +underlying \code{rstan} function. By default \code{adapt_delta = 0.9} and +\code{max_treedepth = 12} though these settings can be overwritten.} \item{save_warmup}{Logical, defaults to FALSE. Should warmup progress be saved.} diff --git a/man/stan_sampling_opts.Rd b/man/stan_sampling_opts.Rd index e1b89d466..cc09c3dde 100644 --- a/man/stan_sampling_opts.Rd +++ b/man/stan_sampling_opts.Rd @@ -32,8 +32,8 @@ When using multiple chains iterations per chain is samples / chains.} \item{chains}{Numeric, defaults to 4. Number of MCMC chains to use.} \item{control}{List, defaults to empty. control parameters to pass to -underlying \code{rstan} function. By default \code{adapt_delta = 0.95} and -\code{max_treedepth = 15} though these settings can be overwritten.} +underlying \code{rstan} function. 
By default \code{adapt_delta = 0.9} and +\code{max_treedepth = 12} though these settings can be overwritten.} \item{save_warmup}{Logical, defaults to FALSE. Should warmup progress be saved.} diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index e5e31d564..d6614247f 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -3,7 +3,8 @@ library("lifecycle") if (identical(Sys.getenv("NOT_CRAN"), "true")) { files <- c( - "convolve.stan", "pmfs.stan", "observation_model.stan", "secondary.stan", + "convolve.stan", "gaussian_process.stan", "pmfs.stan", + "observation_model.stan", "secondary.stan", "rt.stan", "infections.stan", "delays.stan", "generated_quantities.stan" ) if (!(tolower(Sys.info()[["sysname"]]) %in% "windows")) { diff --git a/tests/testthat/test-create_gp_data.R b/tests/testthat/test-create_gp_data.R new file mode 100644 index 000000000..b1f7cc765 --- /dev/null +++ b/tests/testthat/test-create_gp_data.R @@ -0,0 +1,51 @@ +test_that("create_gp_data returns correct default values when GP is disabled", { + data <- list(t = 30, seeding_time = 7, horizon = 7, future_fixed = 0, fixed_from = 0) + restricted_time <- 30 - 7 - 1 + times <- seq_len(restricted_time) + gp_data <- create_gp_data(NULL, data) + expect_equal(gp_data$fixed, 1) + expect_equal(gp_data$stationary, 1) + expect_equal(gp_data$M, 5) # (30 - 7) * 0.2 + expect_equal(gp_data$L, 2.43, tolerance = 0.01) + expect_equal(gp_data$ls_meanlog, convert_to_logmean(21 / sd(times), 7 / sd(times))) + expect_equal(gp_data$ls_sdlog, convert_to_logsd(21, 7)) + expect_equal(gp_data$ls_min, 0) + expect_equal(gp_data$ls_max, 3.54, tolerance = 0.01) + expect_equal(gp_data$alpha_sd, 0.01) + expect_equal(gp_data$gp_type, 2) # Default to Matern + expect_equal(gp_data$nu, 3 / 2) + expect_equal(gp_data$w0, 1.0) +}) + +test_that("create_gp_data sets correct gp_type and nu for different kernels", { + data <- list(t = 30, seeding_time = 7, horizon = 7, future_fixed = 0, fixed_from = 0, stationary = 0) + + 
gp <- gp_opts(kernel = "se") + gp_data <- create_gp_data(gp, data) + expect_equal(gp_data$gp_type, 0) + expect_equal(gp_data$nu, Inf) + + gp <- gp_opts(kernel = "periodic") + gp_data <- create_gp_data(gp, data) + expect_equal(gp_data$gp_type, 1) + expect_equal(gp_data$nu, 3 / 2) # Default Matern order + expect_equal(gp_data$w0, 1.0) + + gp <- gp_opts(kernel = "ou") + gp_data <- create_gp_data(gp, data) + expect_equal(gp_data$gp_type, 2) + expect_equal(gp_data$nu, 1 / 2) +}) + +test_that("create_gp_data correctly adjusts ls_max", { + data <- list(t = 30, seeding_time = 7, horizon = 7, future_fixed = 0, fixed_from = 0, stationary = 0) + gp <- gp_opts(ls_max = 50) + gp_data <- create_gp_data(gp, data) + expect_equal(gp_data$ls_max, 3.39, tolerance = 0.01) # 30 - 7 - 7 +}) + +test_that("create_gp_data correctly handles future_fixed", { + data <- list(t = 30, seeding_time = 7, horizon = 7, future_fixed = 1, fixed_from = 2, stationary = 0) + gp_data <- create_gp_data(gp_opts(), data) + expect_equal(gp_data$M, 4) +}) diff --git a/tests/testthat/test-estimate_infections.R b/tests/testthat/test-estimate_infections.R index a5af77db8..b48d33be5 100644 --- a/tests/testthat/test-estimate_infections.R +++ b/tests/testthat/test-estimate_infections.R @@ -38,6 +38,20 @@ test_that("estimate_infections successfully returns estimates using default sett test_estimate_infections(reported_cases) }) +test_that("estimate_infections successfully returns estimates using a Matern 5/2 kernel", { + skip_on_cran() + test_estimate_infections( + reported_cases, gp = gp_opts(kernel = "matern", matern_order = 5 / 2) + ) +}) + +test_that("estimate_infections successfully returns estimates using a periodic kernel", { + skip_on_cran() + test_estimate_infections( + reported_cases, gp = gp_opts(kernel = "periodic") + ) +}) + test_that("estimate_infections successfully returns estimates when passed NA values", { skip_on_cran() reported_cases_na <- data.table::copy(reported_cases) @@ -85,7 +99,6 @@ 
test_that("estimate_infections successfully returns estimates using a single bre ) }) - test_that("estimate_infections successfully returns estimates using a random walk", { skip_on_cran() test_estimate_infections(reported_cases, gp = NULL, rt = rt_opts(rw = 7)) @@ -110,7 +123,6 @@ test_that("estimate_infections works with different kernels", { test_estimate_infections(reported_cases, gp = gp_opts(kernel = "se")) test_estimate_infections(reported_cases, gp = gp_opts(kernel = "ou")) test_estimate_infections(reported_cases, gp = gp_opts(matern_order = 5 / 2)) - expect_error(gp_opts(matern_order = 4)) }) test_that("estimate_infections fails as expected when given a very short timeout", { diff --git a/tests/testthat/test-gp_opts.R b/tests/testthat/test-gp_opts.R new file mode 100644 index 000000000..d91c18a87 --- /dev/null +++ b/tests/testthat/test-gp_opts.R @@ -0,0 +1,44 @@ +test_that("gp_opts returns correct default values", { + gp <- gp_opts() + expect_equal(gp$basis_prop, 0.2) + expect_equal(gp$boundary_scale, 1.5) + expect_equal(gp$ls_mean, 21) + expect_equal(gp$ls_sd, 7) + expect_equal(gp$ls_min, 0) + expect_equal(gp$ls_max, 60) + expect_equal(gp$alpha_sd, 0.01) + expect_equal(gp$kernel, "matern") + expect_equal(gp$matern_order, 3 / 2) + expect_equal(gp$w0, 1.0) +}) + +test_that("gp_opts sets matern_order to Inf for squared exponential kernel", { + gp <- gp_opts(kernel = "se") + expect_equal(gp$matern_order, Inf) +}) + +test_that("gp_opts sets matern_order to 1/2 for Ornstein-Uhlenbeck kernel", { + gp <- gp_opts(kernel = "ou") + expect_equal(gp$matern_order, 1 / 2) +}) + +test_that("gp_opts warns for uncommon Matern kernel orders", { + expect_warning(gp_opts(matern_order = 2), "Uncommon Matern kernel order") +}) + +test_that("gp_opts handles deprecated matern_type parameter", { + lifecycle::expect_deprecated(gp_opts(matern_type = 5 / 2)) + gp <- gp_opts(matern_type = 5 / 2) + expect_equal(gp$matern_order, 5 / 2) +}) + +test_that("gp_opts stops for incompatible 
matern_order and matern_type", { + expect_error( + gp_opts(matern_order = 3 / 2, matern_type = 5 / 2), + "Incompatible `matern_order` and `matern_type`" + ) +}) + +test_that("gp_opts warns about uncommon Matern kernel orders", { + expect_warning(gp_opts(matern_order = 2), "Uncommon Matern kernel order") +}) diff --git a/tests/testthat/test-stan-convole.R b/tests/testthat/test-stan-convole.R index 5eb246bcb..d249da272 100644 --- a/tests/testthat/test-stan-convole.R +++ b/tests/testthat/test-stan-convole.R @@ -30,9 +30,9 @@ test_that("convolve_with_rev_pmf can combine two pmfs as expected", { test_that("convolve_with_rev_pmf performs the same as a numerical convolution", { # Sample and analytical PMFs for two Poisson distributions - x <- rpois(10000, 3) + x <- rpois(100000, 3) xpmf <- dpois(0:20, 3) - y <- rpois(10000, 5) + y <- rpois(100000, 5) ypmf <- dpois(0:20, 5) # Add sampled Poisson distributions up to get combined distribution z <- x + y diff --git a/tests/testthat/test-stan-guassian-process.R b/tests/testthat/test-stan-guassian-process.R new file mode 100644 index 000000000..ac56950e9 --- /dev/null +++ b/tests/testthat/test-stan-guassian-process.R @@ -0,0 +1,142 @@ +skip_on_cran() +skip_on_os("windows") + +# Helper functions +linspaced_vector <- function(n, start, end) { + seq(start, end, length.out = n) +} + +to_vector <- function(x) { + as.vector(x) +} + +test_that("diagSPD_EQ returns correct dimensions and values", { + alpha <- 1.0 + rho <- 2.0 + L <- 1.0 + M <- 5 + result <- diagSPD_EQ(alpha, rho, L, M) + expect_equal(length(result), M) + expect_true(all(result > 0)) # Expect spectral density to be positive + # Check specific values for known inputs + indices <- linspaced_vector(M, 1, M) + expected_result <- alpha * sqrt(sqrt(2 * pi) * rho) * exp(-0.25 * (rho * pi / (2 * L))^2 * indices^2) + expect_equal(result, expected_result, tolerance = 1e-8) +}) + +test_that("diagSPD_Matern returns correct dimensions and values", { + nu <- 1.5 + alpha <- 1.0 + rho 
<- 2.0 + L <- 1.0 + M <- 5 + result <- diagSPD_Matern(nu, alpha, rho, L, M) + expect_equal(length(result), M) + expect_true(all(result > 0)) # Expect spectral density to be positive + # Check specific values for known inputs + indices <- linspaced_vector(M, 1, M) + factor <- 2 * alpha * (sqrt(2 * nu) / rho)^nu + denom <- (sqrt(2 * nu) / rho)^2 + (pi / (2 * L) * indices)^(nu + 0.5) + expected_result <- factor / denom + expect_equal(result, expected_result, tolerance = 1e-8) +}) + +test_that("diagSPD_Periodic returns correct dimensions and values", { + alpha <- 1.0 + rho <- 2.0 + M <- 5 + result <- diagSPD_Periodic(alpha, rho, M) + expect_equal(length(result), 2 * M) # Expect double the dimensions due to append_row + expect_true(all(result > 0)) # Expect spectral density to be positive +}) + +test_that("PHI returns correct dimensions and values", { + N <- 5 + M <- 3 + L <- 1.0 + x <- seq(0, 1, length.out = N) + result <- PHI(N, M, L, x) + expect_equal(dim(result), c(N, M)) + # Check specific values for known inputs + expected_result <- sin(outer(x + L, 1:M, function(x, m) pi / (2 * L) * x * m)) / sqrt(L) + expect_equal(result, expected_result, tolerance = 1e-8) +}) + +test_that("PHI_periodic returns correct dimensions and values", { + N <- 5 + M <- 3 + w0 <- 1.0 + x <- seq(0, 1, length.out = N) + result <- PHI_periodic(N, M, w0, x) + expect_equal(dim(result), c(N, 2 * M)) # Cosine and sine terms + # Check specific values for known inputs + mw0x <- outer(w0 * x, 1:M, function(x, m) x * m) + expected_result <- cbind(cos(mw0x), sin(mw0x)) + expect_equal(result, expected_result, tolerance = 1e-8) +}) + +test_that("setup_noise returns correct count of noise terms", { + ot_h <- 10 + t <- 10 + horizon <- 0 + estimate_r <- 1 + stationary <- 1 + future_fixed <- 0 + fixed_from <- 0 + result <- setup_noise(ot_h, t, horizon, estimate_r, stationary, future_fixed, fixed_from) + expect_equal(result, ot_h) + # Test with different parameters + result <- setup_noise(ot_h, t, horizon, 
estimate_r, 0, future_fixed, fixed_from) + expect_equal(result, ot_h - 1) + result <- setup_noise(ot_h, t, horizon, 0, stationary, future_fixed, fixed_from) + expect_equal(result, t) + result <- setup_noise(ot_h, t, 2, estimate_r, stationary, 1, 3) + expect_equal(result, ot_h - 2 + 3) +}) + +test_that("setup_gp returns correct dimensions and values", { + M <- 3 + L <- 1.0 + dimension <- 5 + is_periodic <- 0 + w0 <- 1.0 + result <- setup_gp(M, L, dimension, is_periodic, w0) + expect_equal(dim(result), c(dimension, M)) + # Compare with direct PHI call + x <- linspaced_vector(dimension, 1, dimension) + x <- (x - mean(x)) / sd(x) + expected_result <- PHI(dimension, M, L, x) + expect_equal(result, expected_result, tolerance = 1e-8) +}) + +test_that("setup_gp with periodic basis functions returns correct dimensions and values", { + M <- 3 + L <- 1.0 + dimension <- 5 + is_periodic <- 1 + w0 <- 1.0 + result <- setup_gp(M, L, dimension, is_periodic, w0) + expect_equal(dim(result), c(dimension, 2 * M)) # Cosine and sine terms + # Compare with direct PHI_periodic call + x <- linspaced_vector(dimension, 1, dimension) + x <- (x - mean(x)) / sd(x) + expected_result <- PHI_periodic(dimension, M, w0, x) + expect_equal(result, expected_result, tolerance = 1e-8) +}) + +test_that("update_gp returns correct dimensions and values", { + M <- 3 + L <- 1.0 + alpha <- 1.0 + rho <- 2.0 + eta <- rep(1, M) + PHI <- matrix(runif(15), nrow = 5) # 5 observations, 3 basis functions + type <- 0 + nu <- 1.5 # Not used in SE case + result <- update_gp(PHI, M, L, alpha, rho, eta, type, nu) + expect_equal(length(result), nrow(PHI)) # Should match number of observations + # Check specific values for known inputs + diagSPD <- diagSPD_EQ(alpha, rho, L, M) + expected_result <- PHI %*% (diagSPD * eta) + expect_equal(matrix(result, ncol = 1), expected_result, tolerance = 1e-8) +})