Skip to content

Commit

Permalink
move lengthscale prior to dist_spec (#890)
Browse files Browse the repository at this point in the history
* move rescaled_rho to params framework

* update stan model

* update tests

* update docs

* add news item

* scale in model

* update test

* update news

* remove obsolete tests
  • Loading branch information
sbfnk authored Dec 19, 2024
1 parent dbb6536 commit 3388298
Show file tree
Hide file tree
Showing 12 changed files with 79 additions and 111 deletions.
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

- A bug was fixed where the initial growth was never estimated (i.e. the prior mean was always zero). By @sbfnk in #853 and reviewed by @seabbs.
- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and reviewed by @seabbs.
- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and #890 and reviewed by @seabbs.
- The Gaussian Process lengthscale is now scaled internally by half the length of the time series. By @sbfnk in #890 and reviewed by @seabbs.
- A bug was fixed where `plot.dist_spec()` wasn't throwing an informative error due to an incomplete check for the max of the specified delay. By @jamesmbaazam in #858 and reviewed by @.

## Package changes
Expand Down
31 changes: 3 additions & 28 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -362,31 +362,14 @@ create_gp_data <- function(gp = gp_opts(), data) {
time <- time - 1
}

obs_time <- data$t - data$seeding_time
if (gp$ls_max > obs_time) {
gp$ls_max <- obs_time
}

times <- seq_len(time)

rescaled_times <- (times - mean(times)) / sd(times)
gp$ls_mean <- gp$ls_mean / sd(times)
gp$ls_sd <- gp$ls_sd / sd(times)
gp$ls_min <- gp$ls_min / sd(times)
gp$ls_max <- gp$ls_max / sd(times)

# basis functions
M <- ceiling(time * gp$basis_prop)

# map settings to underlying gp stan requirements
gp_data <- list(
fixed = as.numeric(fixed),
M = M,
L = gp$boundary_scale * max(rescaled_times),
ls_meanlog = convert_to_logmean(gp$ls_mean, gp$ls_sd),
ls_sdlog = convert_to_logsd(gp$ls_mean, gp$ls_sd),
ls_min = gp$ls_min,
ls_max = gp$ls_max,
L = gp$boundary_scale,
gp_type = data.table::fcase(
gp$kernel == "se", 0,
gp$kernel == "periodic", 1,
Expand Down Expand Up @@ -564,11 +547,13 @@ create_stan_data <- function(data, seeding_time,
stan_data,
create_stan_params(
alpha = gp$alpha,
rho = gp$ls,
R0 = rt$prior,
frac_obs = obs$scale,
rep_phi = obs$phi,
lower_bounds = c(
alpha = 0,
rho = 0,
R0 = 0,
frac_obs = 0,
rep_phi = 0
Expand Down Expand Up @@ -623,18 +608,8 @@ create_initial_conditions <- function(data) {
if (data$fixed == 0) {
out$eta <- array(rnorm(
ifelse(data$gp_type == 1, data$M * 2, data$M), mean = 0, sd = 0.1))
out$rescaled_rho <- array(rlnorm(1,
meanlog = data$ls_meanlog,
sdlog = ifelse(data$ls_sdlog > 0, data$ls_sdlog, 0.01)
))
out$rescaled_rho <- array(data.table::fcase(
out$rescaled_rho > data$ls_max, data$ls_max - 0.001,
out$rescaled_rho < data$ls_min, data$ls_min + 0.001,
default = out$rescaled_rho
))
} else {
out$eta <- array(numeric(0))
out$rescaled_rho <- array(numeric(0))
}
if (data$estimate_r == 1) {
out$initial_infections <- array(rnorm(1, data$prior_infections, 0.2))
Expand Down
54 changes: 40 additions & 14 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -465,19 +465,19 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#' Defines a list specifying the structure of the approximate Gaussian
#' process. Custom settings can be supplied which override the defaults.
#'
#' @param ls_mean Numeric, defaults to 21 days. The mean of the lognormal
#' length scale.
#' @param ls_mean Deprecated; use `ls` instead.
#'
#' @param ls_sd Numeric, defaults to 7 days. The standard deviation of the log
#' normal length scale. If \code{ls_sd = 0}, inverse-gamma prior on Gaussian
#' process length scale will be used with recommended parameters
#' \code{inv_gamma(1.499007, 0.057277 * ls_max)}.
#' @param ls_sd Deprecated; use `ls` instead.
#'
#' @param ls_min Numeric, defaults to 0. The minimum value of the length scale.
#' @param ls_min Deprecated; use `ls` instead.
#'
#' @param ls_max Numeric, defaults to 60. The maximum value of the length
#' scale. Updated in [create_gp_data()] to be the length of the input data if
#' this is smaller.
#' @param ls_max Deprecated; use `ls` instead.
#'
#' @param ls A `<dist_spec>` giving the prior distribution of the lengthscale
#' parameter of the Gaussian process kernel on the scale of days. Defaults to
#' a Lognormal distribution with mean 21 days, sd 7 days and maximum 60 days:
#' `LogNormal(mean = 21, sd = 7, max = 60)`. A lower limit of 0 is enforced
#' automatically to ensure positivity.
#'
#' @param alpha A `<dist_spec>` giving the prior distribution of the magnitude
#' parameter of the Gaussian process kernel. Should be approximately the
Expand Down Expand Up @@ -537,6 +537,7 @@ gp_opts <- function(basis_prop = 0.2,
ls_sd = 7,
ls_min = 0,
ls_max = 60,
ls = LogNormal(mean = 21, sd = 7, max = 60),
alpha = Normal(mean = 0, sd = 0.01),
kernel = c("matern", "se", "ou", "periodic"),
matern_order = 3 / 2,
Expand All @@ -559,6 +560,34 @@ gp_opts <- function(basis_prop = 0.2,
"1.7.0", "gp_opts(alpha_sd)", "gp_opts(alpha)"
)
}
if (!missing(ls_mean) || !missing(ls_sd) || !missing(ls_min) ||
!missing(ls_max)) {
if (!missing(ls)) {
cli_abort(
c(
"!" = "Both {.var ls} and at least one legacy argument
({.var ls_mean}, {.var ls_sd}, {.var ls_min}, {.var ls_max}) have been
specified.",
"i" = "Only one of the should be used."
)
)
}
cli_warn(c(
"!" = "Specifying lengthscale priors via the {.var ls_mean}, {.var ls_sd},
{.var ls_min}, and {.var ls_max} arguments is deprecated.",
"i" = "Use the {.var ls} argument instead."
))
if (ls_min > 0) {
cli_abort(
c(
"!" = "Lower lengthscale bounds of greater than 0 are no longer
supported. If this is a feature you need please open an Issue on the
EpiNow2 GitHub repository."
)
)
}
ls <- LogNormal(mean = ls_mean, sd = ls_sd, max = ls_max)
}

if (!missing(matern_type)) {
if (!missing(matern_order) && matern_type != matern_order) {
Expand Down Expand Up @@ -592,10 +621,7 @@ gp_opts <- function(basis_prop = 0.2,
gp <- list(
basis_prop = basis_prop,
boundary_scale = boundary_scale,
ls_mean = ls_mean,
ls_sd = ls_sd,
ls_min = ls_min,
ls_max = ls_max,
ls = ls,
alpha = alpha,
kernel = kernel,
matern_order = matern_order,
Expand Down
1 change: 1 addition & 0 deletions R/simulate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ simulate_infections <- function(estimates, R, initial_infections,

data <- c(data, create_stan_params(
alpha = NULL,
rho = NULL,
R0 = NULL,
frac_obs = obs$scale,
rep_phi = obs$phi
Expand Down
3 changes: 2 additions & 1 deletion inst/stan/data/estimate_infections_params.stan
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
int<lower = 0> alpha_id; // parameter id of alpha (GP magnitude)
int<lower = 0> R0_id; // parameter id of R0
int<lower = 0> rho_id; // parameter id of rho (GP lengthscale)
int<lower = 0> R0_id; // parameter id of R0
int<lower = 0> frac_obs_id; // parameter id of frac_obs
int<lower = 0> rep_phi_id; // parameter id of rep_phi_id
4 changes: 0 additions & 4 deletions inst/stan/data/gaussian_process.stan
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
real L; // boundary value for infections gp
int<lower=1> M; // basis functions for infections gp
real ls_meanlog; // meanlog for gp lengthscale prior
real ls_sdlog; // sdlog for gp lengthscale prior
real<lower=0> ls_min; // Lower bound for the lengthscale
real<lower=0> ls_max; // Upper bound for the lengthscale
int gp_type; // type of gp, 0 = squared exponential, 1 = periodic, 2 = Matern
real nu; // smoothness parameter for Matern kernel (used if gp_type = 2)
real w0; // fundamental frequency for periodic kernel (used if gp_type = 1)
Expand Down
16 changes: 9 additions & 7 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ transformed data {
parameters {
vector<lower = params_lower, upper = params_upper>[n_params_variable] params;
// gaussian process
array[fixed ? 0 : 1] real<lower = ls_min, upper = ls_max> rescaled_rho; // length scale of noise GP
vector[fixed ? 0 : gp_type == 1 ? 2*M : M] eta; // unconstrained noise
// Rt
array[estimate_r] real initial_infections; // seed infections
Expand All @@ -70,6 +69,10 @@ transformed parameters {
alpha_id, params_fixed_lookup, params_variable_lookup, params_value,
params
);
real rescaled_rho = 2 * get_param(
rho_id, params_fixed_lookup, params_variable_lookup,
params_value, params
) / noise_terms;
noise = update_gp(
PHI, M, L, alpha, rescaled_rho, eta, gp_type, nu
);
Expand Down Expand Up @@ -176,9 +179,6 @@ model {
if (!fixed) {
profile("gp lp") {
gaussian_process_lp(eta);
if (gp_type != 3) {
lengthscale_lp(rescaled_rho[1], ls_meanlog, ls_sdlog, ls_min, ls_max);
}
}
}

Expand Down Expand Up @@ -226,16 +226,18 @@ generated quantities {
vector[estimate_r > 0 ? 0 : ot_h] gen_R;
vector[ot_h - 1] r;
vector[return_likelihood ? ot : 0] log_lik;
vector[fixed ? 0 : 1] rho;

profile("generated quantities") {
real rep_phi = get_param(
rep_phi_id, params_fixed_lookup, params_variable_lookup, params_value,
params
);
if (!fixed && gp_type != 3) {
if (!fixed) {
real rescaled_rho = 2 * get_param(
rho_id, params_fixed_lookup, params_variable_lookup,
params_value, params
) / noise_terms;
vector[noise_terms] x = linspaced_vector(noise_terms, 1, noise_terms);
rho[1] = rescaled_rho[1] * sd(x);
}

if (estimate_r == 0) {
Expand Down
32 changes: 7 additions & 25 deletions inst/stan/functions/gaussian_process.stan
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ int setup_noise(int ot_h, int t, int horizon, int estimate_r,
*/
matrix setup_gp(int M, real L, int dimension, int is_periodic, real w0) {
vector[dimension] x = linspaced_vector(dimension, 1, dimension);
x = (x - mean(x)) / sd(x);
x = 2 * (x - mean(x)) / (max(x) - 1);
if (is_periodic) {
return PHI_periodic(dimension, M, w0, x);
} else {
Expand All @@ -165,46 +165,28 @@ matrix setup_gp(int M, real L, int dimension, int is_periodic, real w0) {
* @return A vector of updated noise terms
*/
vector update_gp(matrix PHI, int M, real L, real alpha,
array[] real rho, vector eta, int type, real nu) {
real rho, vector eta, int type, real nu) {
vector[type == 1 ? 2 * M : M] diagSPD; // spectral density

// GP in noise - spectral densities
if (type == 0) {
diagSPD = diagSPD_EQ(alpha, rho[1], L, M);
diagSPD = diagSPD_EQ(alpha, rho, L, M);
} else if (type == 1) {
diagSPD = diagSPD_Periodic(alpha, rho[1], M);
diagSPD = diagSPD_Periodic(alpha, rho, M);
} else if (type == 2) {
if (nu == 0.5) {
diagSPD = diagSPD_Matern12(alpha, rho[1], L, M);
diagSPD = diagSPD_Matern12(alpha, rho, L, M);
} else if (nu == 1.5) {
diagSPD = diagSPD_Matern32(alpha, rho[1], L, M);
diagSPD = diagSPD_Matern32(alpha, rho, L, M);
} else if (nu == 2.5) {
diagSPD = diagSPD_Matern52(alpha, rho[1], L, M);
diagSPD = diagSPD_Matern52(alpha, rho, L, M);
} else {
reject("nu must be one of 1/2, 3/2 or 5/2; found nu=", nu);
}
}
return PHI * (diagSPD .* eta);
}

/**
* Prior for Gaussian process length scale
*
* @param rho Length scale parameter
* @param ls_meanlog Mean of the log of the length scale
* @param ls_sdlog Standard deviation of the log of the length scale
* @param ls_min Minimum length scale
* @param ls_max Maximum length scale
*/
void lengthscale_lp(real rho, real ls_meanlog, real ls_sdlog,
real ls_min, real ls_max) {
if (ls_sdlog > 0) {
rho ~ lognormal(ls_meanlog, ls_sdlog) T[ls_min, ls_max];
} else {
rho ~ inv_gamma(1.499007, 0.057277 * ls_max) T[ls_min, ls_max];
}
}

/**
* Priors for Gaussian process (excluding length scale)
*
Expand Down
21 changes: 11 additions & 10 deletions man/gp_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 1 addition & 15 deletions tests/testthat/test-create_gp_data.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
test_that("create_gp_data returns correct default values when GP is disabled", {
data <- list(t = 30, seeding_time = 7, horizon = 7, future_fixed = 0, fixed_from = 0)
restricted_time <- 30 - 7 - 1
times <- seq_len(restricted_time)
gp_data <- create_gp_data(NULL, data)
expect_equal(gp_data$fixed, 1)
expect_equal(gp_data$stationary, 1)
expect_equal(gp_data$M, 5) # (30 - 7) * 0.2
expect_equal(gp_data$L, 2.43, tolerance = 0.01)
expect_equal(gp_data$ls_meanlog, convert_to_logmean(21 / sd(times), 7 / sd(times)))
expect_equal(gp_data$ls_sdlog, convert_to_logsd(21, 7))
expect_equal(gp_data$ls_min, 0)
expect_equal(gp_data$ls_max, 3.54, tolerance = 0.01)
expect_equal(gp_data$alpha, NULL)
expect_equal(gp_data$L, 1.5)
expect_equal(gp_data$gp_type, 2) # Default to Matern
expect_equal(gp_data$nu, 3 / 2)
expect_equal(gp_data$w0, 1.0)
Expand All @@ -37,13 +30,6 @@ test_that("create_gp_data sets correct gp_type and nu for different kernels", {
expect_equal(gp_data$nu, 1 / 2)
})

test_that("create_gp_data correctly adjusts ls_max", {
data <- list(t = 30, seeding_time = 7, horizon = 7, future_fixed = 0, fixed_from = 0, stationary = 0)
gp <- gp_opts(ls_max = 50)
gp_data <- create_gp_data(gp, data)
expect_equal(gp_data$ls_max, 3.39, tolerance = 0.01) # 30 - 7 - 7
})

test_that("create_gp_data correctly handles future_fixed", {
data <- list(t = 30, seeding_time = 7, horizon = 7, future_fixed = 1, fixed_from = 2, stationary = 0)
gp_data <- create_gp_data(gp_opts(), data)
Expand Down
5 changes: 1 addition & 4 deletions tests/testthat/test-gp_opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@ test_that("gp_opts returns correct default values", {
gp <- gp_opts()
expect_equal(gp$basis_prop, 0.2)
expect_equal(gp$boundary_scale, 1.5)
expect_equal(gp$ls_mean, 21)
expect_equal(gp$ls_sd, 7)
expect_equal(gp$ls_min, 0)
expect_equal(gp$ls_max, 60)
expect_equal(gp$alpha, Normal(0, 0.01))
expect_equal(gp$ls, LogNormal(mean = 21, sd = 7, max = 60))
expect_equal(gp$kernel, "matern")
expect_equal(gp$matern_order, 3 / 2)
expect_equal(gp$w0, 1.0)
Expand Down
Loading

0 comments on commit 3388298

Please sign in to comment.