From 338829836f6da0a203b4d977c4fdde5475eaa2c4 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 19 Dec 2024 18:27:13 +0000 Subject: [PATCH] move lengthscale prior to `dist_spec` (#890) * move rescaled_rho to params framework * update stan model * update tests * update docs * add news item * scale in model * update test * update news * remove obsolete tests --- NEWS.md | 3 +- R/create.R | 31 ++--------- R/opts.R | 54 ++++++++++++++----- R/simulate_infections.R | 1 + .../stan/data/estimate_infections_params.stan | 3 +- inst/stan/data/gaussian_process.stan | 4 -- inst/stan/estimate_infections.stan | 16 +++--- inst/stan/functions/gaussian_process.stan | 32 +++-------- man/gp_opts.Rd | 21 ++++---- tests/testthat/test-create_gp_data.R | 16 +----- tests/testthat/test-gp_opts.R | 5 +- tests/testthat/test-stan-guassian-process.R | 4 +- 12 files changed, 79 insertions(+), 111 deletions(-) diff --git a/NEWS.md b/NEWS.md index 594403165..7e6a5030f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -17,7 +17,8 @@ - A bug was fixed where the initial growth was never estimated (i.e. the prior mean was always zero). By @sbfnk in #853 and reviewed by @seabbs. - A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk. -- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and reviewed by @seabbs. +- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and #890 and reviewed by @seabbs. +- The Gaussian Process lengthscale is now scaled internally by half the length of the time series. By @sbfnk in #890 and reviewed by #seabbs. - A bug was fixed where `plot.dist_spec()` wasn't throwing an informative error due to an incomplete check for the max of the specified delay. By @jamesmbaazam in #858 and reviewed by @. ## Package changes diff --git a/R/create.R b/R/create.R index 789bbb444..c4218fa93 100644 --- a/R/create.R +++ b/R/create.R @@ -362,19 +362,6 @@ create_gp_data <- function(gp = gp_opts(), data) { time <- time - 1 } - obs_time <- data$t - data$seeding_time - if (gp$ls_max > obs_time) { - gp$ls_max <- obs_time - } - - times <- seq_len(time) - - rescaled_times <- (times - mean(times)) / sd(times) - gp$ls_mean <- gp$ls_mean / sd(times) - gp$ls_sd <- gp$ls_sd / sd(times) - gp$ls_min <- gp$ls_min / sd(times) - gp$ls_max <- gp$ls_max / sd(times) - # basis functions M <- ceiling(time * gp$basis_prop) @@ -382,11 +369,7 @@ create_gp_data <- function(gp = gp_opts(), data) { gp_data <- list( fixed = as.numeric(fixed), M = M, - L = gp$boundary_scale * max(rescaled_times), - ls_meanlog = convert_to_logmean(gp$ls_mean, gp$ls_sd), - ls_sdlog = convert_to_logsd(gp$ls_mean, gp$ls_sd), - ls_min = gp$ls_min, - ls_max = gp$ls_max, + L = gp$boundary_scale, gp_type = data.table::fcase( gp$kernel == "se", 0, gp$kernel == "periodic", 1, @@ -564,11 +547,13 @@ create_stan_data <- function(data, seeding_time, stan_data, create_stan_params( alpha = gp$alpha, + rho = gp$ls, R0 = rt$prior, frac_obs = obs$scale, rep_phi = obs$phi, lower_bounds = c( alpha = 0, + rho = 0, R0 = 0, frac_obs = 0, rep_phi = 0 @@ -623,18 +608,8 @@ create_initial_conditions <- function(data) { if (data$fixed == 0) { out$eta <- array(rnorm( ifelse(data$gp_type == 1, data$M * 2, data$M), mean = 0, sd = 0.1)) - out$rescaled_rho <- array(rlnorm(1, - meanlog = data$ls_meanlog, - sdlog = ifelse(data$ls_sdlog > 0, data$ls_sdlog, 0.01) - )) - out$rescaled_rho <- array(data.table::fcase( - out$rescaled_rho > data$ls_max, data$ls_max - 0.001, - out$rescaled_rho < data$ls_min, data$ls_min + 0.001, - default = out$rescaled_rho - )) } else { out$eta <- array(numeric(0)) - out$rescaled_rho <- array(numeric(0)) } if (data$estimate_r == 1) { out$initial_infections <- array(rnorm(1, data$prior_infections, 0.2)) diff --git a/R/opts.R b/R/opts.R index 613a6452f..b1aacd852 100644 --- a/R/opts.R +++ b/R/opts.R @@ -465,19 +465,19 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"), #' Defines a list specifying the structure of the approximate Gaussian #' process. Custom settings can be supplied which override the defaults. #' -#' @param ls_mean Numeric, defaults to 21 days. The mean of the lognormal -#' length scale. +#' @param ls_mean Deprecated; use `ls` instead. #' -#' @param ls_sd Numeric, defaults to 7 days. The standard deviation of the log -#' normal length scale. If \code{ls_sd = 0}, inverse-gamma prior on Gaussian -#' process length scale will be used with recommended parameters -#' \code{inv_gamma(1.499007, 0.057277 * ls_max)}. +#' @param ls_sd Deprecated; use `ls` instead. #' -#' @param ls_min Numeric, defaults to 0. The minimum value of the length scale. +#' @param ls_min Deprecated; use `ls` instead. #' -#' @param ls_max Numeric, defaults to 60. The maximum value of the length -#' scale. Updated in [create_gp_data()] to be the length of the input data if -#' this is smaller. +#' @param ls_max Deprecated; use `ls` instead. +#' +#' @param ls A `` giving the prior distribution of the lengthscale +#' parameter of the Gaussian process kernel on the scale of days. Defaults to +#' a Lognormal distribution with mean 21 days, sd 7 days and maximum 60 days: +#' `LogNormal(mean = 21, sd = 7, max = 60)` (a lower limit of 0 will be +#' enforced automatically to ensure positivity) #' #' @param alpha A `` giving the prior distribution of the magnitude #' parameter of the Gaussian process kernel. Should be approximately the @@ -537,6 +537,7 @@ gp_opts <- function(basis_prop = 0.2, ls_sd = 7, ls_min = 0, ls_max = 60, + ls = LogNormal(mean = 21, sd = 7, max = 60), alpha = Normal(mean = 0, sd = 0.01), kernel = c("matern", "se", "ou", "periodic"), matern_order = 3 / 2, @@ -559,6 +560,34 @@ gp_opts <- function(basis_prop = 0.2, "1.7.0", "gp_opts(alpha_sd)", "gp_opts(alpha)" ) } + if (!missing(ls_mean) || !missing(ls_sd) || !missing(ls_min) || + !missing(ls_max)) { + if (!missing(ls)) { + cli_abort( + c( + "!" = "Both {.var ls} and at least one legacy argument + ({.var ls_mean}, {.var ls_sd}, {.var ls_min}, {.var ls_max}) have been + specified.", + "i" = "Only one of the should be used." + ) + ) + } + cli_warn(c( + "!" = "Specifying lengthscale priors via the {.var ls_mean}, {.var ls_sd}, + {.var ls_min}, and {.var ls_max} arguments is deprecated.", + "i" = "Use the {.var ls} argument instead." + )) + if (ls_min > 0) { + cli_abort( + c( + "!" = "Lower lengthscale bounds of greater than 0 are no longer + supported. If this is a feature you need please open an Issue on the + EpiNow2 GitHub repository." + ) + ) + } + ls <- LogNormal(mean = ls_mean, sd = ls_sd, max = ls_max) + } if (!missing(matern_type)) { if (!missing(matern_order) && matern_type != matern_order) { @@ -592,10 +621,7 @@ gp_opts <- function(basis_prop = 0.2, gp <- list( basis_prop = basis_prop, boundary_scale = boundary_scale, - ls_mean = ls_mean, - ls_sd = ls_sd, - ls_min = ls_min, - ls_max = ls_max, + ls = ls, alpha = alpha, kernel = kernel, matern_order = matern_order, diff --git a/R/simulate_infections.R b/R/simulate_infections.R index f3b66cc3f..c28daa6cc 100644 --- a/R/simulate_infections.R +++ b/R/simulate_infections.R @@ -176,6 +176,7 @@ simulate_infections <- function(estimates, R, initial_infections, data <- c(data, create_stan_params( alpha = NULL, + rho = NULL, R0 = NULL, frac_obs = obs$scale, rep_phi = obs$phi diff --git a/inst/stan/data/estimate_infections_params.stan b/inst/stan/data/estimate_infections_params.stan index 3351f5ea3..85be5c1c9 100644 --- a/inst/stan/data/estimate_infections_params.stan +++ b/inst/stan/data/estimate_infections_params.stan @@ -1,4 +1,5 @@ int alpha_id; // parameter id of alpha (GP magnitude) -int R0_id; // parameter id of R0 +int rho_id; // parameter id of rho (GP lengthscale) +int R0_id; // parameter id of R0 int frac_obs_id; // parameter id of frac_obs int rep_phi_id; // parameter id of rep_phi_id diff --git a/inst/stan/data/gaussian_process.stan b/inst/stan/data/gaussian_process.stan index 7990dba8a..9b1fed466 100644 --- a/inst/stan/data/gaussian_process.stan +++ b/inst/stan/data/gaussian_process.stan @@ -1,9 +1,5 @@ real L; // boundary value for infections gp int M; // basis functions for infections gp - real ls_meanlog; // meanlog for gp lengthscale prior - real ls_sdlog; // sdlog for gp lengthscale prior - real ls_min; // Lower bound for the lengthscale - real ls_max; // Upper bound for the lengthscale int gp_type; // type of gp, 0 = squared exponential, 1 = periodic, 2 = Matern real nu; // smoothness parameter for Matern kernel (used if gp_type = 2) real w0; // fundamental frequency for periodic kernel (used if gp_type = 1) diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index 8202c962c..e99b04211 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -43,7 +43,6 @@ transformed data { parameters { vector[n_params_variable] params; // gaussian process - array[fixed ? 0 : 1] real rescaled_rho; // length scale of noise GP vector[fixed ? 0 : gp_type == 1 ? 2*M : M] eta; // unconstrained noise // Rt array[estimate_r] real initial_infections; // seed infections @@ -70,6 +69,10 @@ transformed parameters { alpha_id, params_fixed_lookup, params_variable_lookup, params_value, params ); + real rescaled_rho = 2 * get_param( + rho_id, params_fixed_lookup, params_variable_lookup, + params_value, params + ) / noise_terms; noise = update_gp( PHI, M, L, alpha, rescaled_rho, eta, gp_type, nu ); @@ -176,9 +179,6 @@ model { if (!fixed) { profile("gp lp") { gaussian_process_lp(eta); - if (gp_type != 3) { - lengthscale_lp(rescaled_rho[1], ls_meanlog, ls_sdlog, ls_min, ls_max); - } } } @@ -226,16 +226,18 @@ generated quantities { vector[estimate_r > 0 ? 0 : ot_h] gen_R; vector[ot_h - 1] r; vector[return_likelihood ? ot : 0] log_lik; - vector[fixed ? 0 : 1] rho; profile("generated quantities") { real rep_phi = get_param( rep_phi_id, params_fixed_lookup, params_variable_lookup, params_value, params ); - if (!fixed && gp_type != 3) { + if (!fixed) { + real rescaled_rho = 2 * get_param( + rho_id, params_fixed_lookup, params_variable_lookup, + params_value, params + ) / noise_terms; vector[noise_terms] x = linspaced_vector(noise_terms, 1, noise_terms); - rho[1] = rescaled_rho[1] * sd(x); } if (estimate_r == 0) { diff --git a/inst/stan/functions/gaussian_process.stan b/inst/stan/functions/gaussian_process.stan index e35906f02..b680961b7 100644 --- a/inst/stan/functions/gaussian_process.stan +++ b/inst/stan/functions/gaussian_process.stan @@ -143,7 +143,7 @@ int setup_noise(int ot_h, int t, int horizon, int estimate_r, */ matrix setup_gp(int M, real L, int dimension, int is_periodic, real w0) { vector[dimension] x = linspaced_vector(dimension, 1, dimension); - x = (x - mean(x)) / sd(x); + x = 2 * (x - mean(x)) / (max(x) - 1); if (is_periodic) { return PHI_periodic(dimension, M, w0, x); } else { @@ -165,21 +165,21 @@ matrix setup_gp(int M, real L, int dimension, int is_periodic, real w0) { * @return A vector of updated noise terms */ vector update_gp(matrix PHI, int M, real L, real alpha, - array[] real rho, vector eta, int type, real nu) { + real rho, vector eta, int type, real nu) { vector[type == 1 ? 2 * M : M] diagSPD; // spectral density // GP in noise - spectral densities if (type == 0) { - diagSPD = diagSPD_EQ(alpha, rho[1], L, M); + diagSPD = diagSPD_EQ(alpha, rho, L, M); } else if (type == 1) { - diagSPD = diagSPD_Periodic(alpha, rho[1], M); + diagSPD = diagSPD_Periodic(alpha, rho, M); } else if (type == 2) { if (nu == 0.5) { - diagSPD = diagSPD_Matern12(alpha, rho[1], L, M); + diagSPD = diagSPD_Matern12(alpha, rho, L, M); } else if (nu == 1.5) { - diagSPD = diagSPD_Matern32(alpha, rho[1], L, M); + diagSPD = diagSPD_Matern32(alpha, rho, L, M); } else if (nu == 2.5) { - diagSPD = diagSPD_Matern52(alpha, rho[1], L, M); + diagSPD = diagSPD_Matern52(alpha, rho, L, M); } else { reject("nu must be one of 1/2, 3/2 or 5/2; found nu=", nu); } @@ -187,24 +187,6 @@ vector update_gp(matrix PHI, int M, real L, real alpha, return PHI * (diagSPD .* eta); } -/** - * Prior for Gaussian process length scale - * - * @param rho Length scale parameter - * @param ls_meanlog Mean of the log of the length scale - * @param ls_sdlog Standard deviation of the log of the length scale - * @param ls_min Minimum length scale - * @param ls_max Maximum length scale - */ -void lengthscale_lp(real rho, real ls_meanlog, real ls_sdlog, - real ls_min, real ls_max) { - if (ls_sdlog > 0) { - rho ~ lognormal(ls_meanlog, ls_sdlog) T[ls_min, ls_max]; - } else { - rho ~ inv_gamma(1.499007, 0.057277 * ls_max) T[ls_min, ls_max]; - } -} - /** * Priors for Gaussian process (excluding length scale) * diff --git a/man/gp_opts.Rd b/man/gp_opts.Rd index 3bbe91930..955fd47d6 100644 --- a/man/gp_opts.Rd +++ b/man/gp_opts.Rd @@ -11,6 +11,7 @@ gp_opts( ls_sd = 7, ls_min = 0, ls_max = 60, + ls = LogNormal(mean = 21, sd = 7, max = 60), alpha = Normal(mean = 0, sd = 0.01), kernel = c("matern", "se", "ou", "periodic"), matern_order = 3/2, @@ -32,19 +33,19 @@ proportion of basis functions. See (Riutort-Mayol et al. 2020 approximate Gaussian process. See (Riutort-Mayol et al. 2020 \url{https://arxiv.org/abs/2004.11408}) for advice on updating this default.} -\item{ls_mean}{Numeric, defaults to 21 days. The mean of the lognormal -length scale.} +\item{ls_mean}{Deprecated; use \code{ls} instead.} -\item{ls_sd}{Numeric, defaults to 7 days. The standard deviation of the log -normal length scale. If \code{ls_sd = 0}, inverse-gamma prior on Gaussian -process length scale will be used with recommended parameters -\code{inv_gamma(1.499007, 0.057277 * ls_max)}.} +\item{ls_sd}{Deprecated; use \code{ls} instead.} -\item{ls_min}{Numeric, defaults to 0. The minimum value of the length scale.} +\item{ls_min}{Deprecated; use \code{ls} instead.} -\item{ls_max}{Numeric, defaults to 60. The maximum value of the length -scale. Updated in \code{\link[=create_gp_data]{create_gp_data()}} to be the length of the input data if -this is smaller.} +\item{ls_max}{Deprecated; use \code{ls} instead.} + +\item{ls}{A \verb{} giving the prior distribution of the lengthscale +parameter of the Gaussian process kernel on the scale of days. Defaults to +a Lognormal distribution with mean 21 days, sd 7 days and maximum 60 days: +\code{LogNormal(mean = 21, sd = 7, max = 60)} (a lower limit of 0 will be +enforced automatically to ensure positivity)} \item{alpha}{A \verb{} giving the prior distribution of the magnitude parameter of the Gaussian process kernel. Should be approximately the diff --git a/tests/testthat/test-create_gp_data.R b/tests/testthat/test-create_gp_data.R index d33dccd76..604c966e1 100644 --- a/tests/testthat/test-create_gp_data.R +++ b/tests/testthat/test-create_gp_data.R @@ -1,17 +1,10 @@ test_that("create_gp_data returns correct default values when GP is disabled", { data <- list(t = 30, seeding_time = 7, horizon = 7, future_fixed = 0, fixed_from = 0) - restricted_time <- 30 - 7 - 1 - times <- seq_len(restricted_time) gp_data <- create_gp_data(NULL, data) expect_equal(gp_data$fixed, 1) expect_equal(gp_data$stationary, 1) expect_equal(gp_data$M, 5) # (30 - 7) * 0.2 - expect_equal(gp_data$L, 2.43, tolerance = 0.01) - expect_equal(gp_data$ls_meanlog, convert_to_logmean(21 / sd(times), 7 / sd(times))) - expect_equal(gp_data$ls_sdlog, convert_to_logsd(21, 7)) - expect_equal(gp_data$ls_min, 0) - expect_equal(gp_data$ls_max, 3.54, tolerance = 0.01) - expect_equal(gp_data$alpha, NULL) + expect_equal(gp_data$L, 1.5) expect_equal(gp_data$gp_type, 2) # Default to Matern expect_equal(gp_data$nu, 3 / 2) expect_equal(gp_data$w0, 1.0) @@ -37,13 +30,6 @@ test_that("create_gp_data sets correct gp_type and nu for different kernels", { expect_equal(gp_data$nu, 1 / 2) }) -test_that("create_gp_data correctly adjusts ls_max", { - data <- list(t = 30, seeding_time = 7, horizon = 7, future_fixed = 0, fixed_from = 0, stationary = 0) - gp <- gp_opts(ls_max = 50) - gp_data <- create_gp_data(gp, data) - expect_equal(gp_data$ls_max, 3.39, tolerance = 0.01) # 30 - 7 - 7 -}) - test_that("create_gp_data correctly handles future_fixed", { data <- list(t = 30, seeding_time = 7, horizon = 7, future_fixed = 1, fixed_from = 2, stationary = 0) gp_data <- create_gp_data(gp_opts(), data) diff --git a/tests/testthat/test-gp_opts.R b/tests/testthat/test-gp_opts.R index 47e4f6186..0289972ec 100644 --- a/tests/testthat/test-gp_opts.R +++ b/tests/testthat/test-gp_opts.R @@ -2,11 +2,8 @@ test_that("gp_opts returns correct default values", { gp <- gp_opts() expect_equal(gp$basis_prop, 0.2) expect_equal(gp$boundary_scale, 1.5) - expect_equal(gp$ls_mean, 21) - expect_equal(gp$ls_sd, 7) - expect_equal(gp$ls_min, 0) - expect_equal(gp$ls_max, 60) expect_equal(gp$alpha, Normal(0, 0.01)) + expect_equal(gp$ls, LogNormal(mean = 21, sd = 7, max = 60)) expect_equal(gp$kernel, "matern") expect_equal(gp$matern_order, 3 / 2) expect_equal(gp$w0, 1.0) diff --git a/tests/testthat/test-stan-guassian-process.R b/tests/testthat/test-stan-guassian-process.R index 0a4717065..14936c5fe 100644 --- a/tests/testthat/test-stan-guassian-process.R +++ b/tests/testthat/test-stan-guassian-process.R @@ -126,7 +126,7 @@ test_that("setup_gp returns correct dimensions and values", { expect_equal(dim(result), c(dimension, M)) # Compare with direct PHI call x <- linspaced_vector(dimension, 1, dimension) - x <- (x - mean(x)) / sd(x) + x <- 2 * (x - mean(x)) / (max(x) - 1) expected_result <- PHI(dimension, M, L, x) expect_equal(result, expected_result, tolerance = 1e-8) }) @@ -141,7 +141,7 @@ test_that("setup_gp with periodic basis functions returns correct dimensions and expect_equal(dim(result), c(dimension, 2 * M)) # Cosine and sine terms # Compare with direct PHI_periodic call x <- linspaced_vector(dimension, 1, dimension) - x <- (x - mean(x)) / sd(x) + x <- 2 * (x - mean(x)) / (max(x) - 1) expected_result <- PHI_periodic(dimension, M, w0, x) expect_equal(result, expected_result, tolerance = 1e-8) })