Skip to content

Commit

Permalink
Additional GP kernels (#741)
Browse files Browse the repository at this point in the history
* lint GP code

* add more kernels and speed up computation

* add ou explicitly

* linting

* add tests

* add news item

* update GP vignette

* linting / fixing bits

* spelling
  • Loading branch information
sbfnk authored Aug 9, 2024
1 parent 22e5b22 commit d803f1f
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 36 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- `epinow()` now returns the "timing" output in a "time difference"" format that is easier to understand and work with. By @jamesmbaazam in #688 and reviewed by @sbfnk.
- The interface for defining delay distributions has been generalised to also cater for continuous distributions
- When defining probability distributions these can now be truncated using the `tolerance` argument
- Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in # and reviewed by @.

## Bug fixes

Expand Down
7 changes: 4 additions & 3 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,10 @@ create_gp_data <- function(gp = gp_opts(), data) {
ls_max = data$t - data$seeding_time - data$horizon,
alpha_sd = gp$alpha_sd,
gp_type = data.table::fcase(
gp$kernel == "se", 0,
gp$kernel == "matern", 1,
default = 0
is.infinite(gp$matern_order), 0,
gp$matern_order == 1 / 2, 1,
gp$matern_order == 3 / 2, 2,
default = 3
)
)

Expand Down
60 changes: 48 additions & 12 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -422,12 +422,18 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#' the expected standard deviation of the logged Rt.
#'
#' @param kernel Character string, the type of kernel required. Currently
#' supporting the squared exponential kernel ("se") and the 3 over 2 Matern
#' kernel ("matern", with `matern_type = 3/2`). Defaulting to the Matern 3 over
#' 2 kernel as discontinuities are expected in Rt and infections.
#' supporting the squared exponential kernel ("se", or "matern" with
#' 'matern_order = Inf'), 3 over 2 oder 5 over 2 Matern kernel ("matern", with
#' `matern_order = 3/2` (default) or `matern_order = 5/2`, respectively), or
#' Orstein-Uhlenbeck ("ou", or "matern" with 'matern_order = 1/2'). Defaulting
#' to the Matérn 3 over 2 kernel for a balance of smoothness and
#' discontinuities.
#'
#' @param matern_type Numeric, defaults to 3/2. Type of Matern Kernel to use.
#' Currently only the Matern 3/2 kernel is supported.
#' @param matern_order Numeric, defaults to 3/2. Order of Matérn Kernel to use.
#' Currently the orders 1/2, 3/2, 5/2 and Inf are supported.
#'
#' @param matern_type Deprated; Numeric, defaults to 3/2. Order of Matérn Kernel
#' to use. Currently the orders 1/2, 3/2, 5/2 and Inf are supported.
#'
#' @param basis_prop Numeric, proportion of time points to use as basis
#' functions. Defaults to 0.2. Decreasing this value results in a decrease in
Expand Down Expand Up @@ -456,8 +462,41 @@ gp_opts <- function(basis_prop = 0.2,
ls_min = 0,
ls_max = 60,
alpha_sd = 0.05,
kernel = c("matern_3/2", "se"),
matern_type = 3 / 2) {
kernel = c("matern", "se", "ou"),
matern_order = 3 / 2,
matern_type) {
lifecycle::deprecate_warn(
"1.6.0", "gp_opts(matern_type)", "gp_opts(matern_order)"
)
if (!missing(matern_type)) {
if (!missing(matern_order) && matern_type == matern_order) {
stop(
"Incompatible `matern_order` and `matern_type`. ",
"Use `matern_order` only."
)
}
matern_order <- matern_type
}

kernel <- arg_match(kernel)
if (kernel == "se") {
if (!missing(matern_order) && is.finite(matern_order)) {
stop("Squared exponential kernel must have matern order unset or `Inf`.")
}
matern_order <- Inf
} else if (kernel == "ou") {
if (!missing(matern_order) && matern_order != 1 / 2) {
stop("Ornstein-Uhlenbeck kernel must have matern order unset or `1 / 2`.") ## nolint: nonportable_path_linter
}
matern_order <- 1 / 2
} else if (!(is.infinite(matern_order) ||
matern_order %in% c(1 / 2, 3 / 2, 5 / 2))) {
stop(
"only the Matern kernels of order `1 / 2`, `3 / 2`, `5 / 2` or `Inf` ", ## nolint: nonportable_path_linter
"are currently supported"
)
}

gp <- list(
basis_prop = basis_prop,
boundary_scale = boundary_scale,
Expand All @@ -466,13 +505,10 @@ gp_opts <- function(basis_prop = 0.2,
ls_min = ls_min,
ls_max = ls_max,
alpha_sd = alpha_sd,
kernel = arg_match(kernel),
matern_type = matern_type
kernel = kernel,
matern_order = matern_order
)

if (gp$matern_type != 3 / 2) {
stop("only the Matern 3/2 kernel is currently supported") # nolint
}
attr(gp, "class") <- c("gp_opts", class(gp))
return(gp)
}
Expand Down
4 changes: 3 additions & 1 deletion inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ transformed data{
int ot = t - seeding_time - horizon; // observed time
int ot_h = ot + horizon; // observed time + forecast horizon
// gaussian process
int noise_terms = setup_noise(ot_h, t, horizon, estimate_r, stationary, future_fixed, fixed_from);
int noise_terms = setup_noise(
ot_h, t, horizon, estimate_r, stationary, future_fixed, fixed_from
);
matrix[noise_terms, M] PHI = setup_gp(M, L, noise_terms); // basis function
// Rt
real r_logmean = log(r_mean^2 / sqrt(r_sd^2 + r_mean^2));
Expand Down
49 changes: 40 additions & 9 deletions inst/stan/functions/gaussian_process.stan
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,56 @@
// see here for details: https://arxiv.org/pdf/2004.11408.pdf
real lambda(real L, int m) {
real lam;
lam = ((m*pi())/(2*L))^2;
lam = ((m * pi())/(2 * L))^2;
return lam;
}

// eigenfunction for approximate hilbert space gp
// see here for details: https://arxiv.org/pdf/2004.11408.pdf
vector phi(real L, int m, vector x) {
vector[rows(x)] fi;
fi = 1/sqrt(L) * sin(m*pi()/(2*L) * (x+L));
fi = 1/sqrt(L) * sin(m * pi()/(2 * L) * (x + L));
return fi;
}

// spectral density of the exponential quadratic kernal
real spd_se(real alpha, real rho, real w) {
real S;
S = (alpha^2) * sqrt(2*pi()) * rho * exp(-0.5*(rho^2)*(w^2));
// S = (alpha^2) * sqrt(2 * pi()) * rho * exp(-0.5 * (rho^2) * (w^2));
S = 2.506628 * alpha * rho * exp(-0.5 * (rho^2) * (w^2));
return S;
}

// spectral density of the Ornstein-Uhlenbeck kernal
real spd_ou(real alpha, real rho, real w) {
real S;
S = 2 * alpha * rho / (1 + rho^2 * w^2);
return S;
}
// spectral density of the Matern 3/2 kernel
real spd_matern(real alpha, real rho, real w) {
real spd_matern32(real alpha, real rho, real w) {
real S;
// S = 4 * alpha^2 * (sqrt(3) / rho)^3 * 1 / ((sqrt(3) / rho)^2 + w^2)^2;
S = 20.78461 * alpha / (rho^3 * (3 / rho^2 + w^2)^2);
return S;
}

real spd_matern52(real alpha, real rho, real w) {
real S;
S = 4*alpha^2 * (sqrt(3)/rho)^3 * 1/((sqrt(3)/rho)^2 + w^2)^2;
// S = 16/3 * alpha^2 * (sqrt(5) / rho)^5 * 1 / ((sqrt(5) / rho)^2 + w^2)^3
S = 298.1424 * alpha / (rho^5 * (5 / rho^2 + w^2)^3);
return S;
}

// setup gaussian process noise dimensions
int setup_noise(int ot_h, int t, int horizon, int estimate_r,
int stationary, int future_fixed, int fixed_from) {
int noise_time = estimate_r > 0 ? (stationary > 0 ? ot_h : ot_h - 1) : t;
int noise_terms = future_fixed > 0 ? (noise_time - horizon + fixed_from) : noise_time;
int noise_terms =
future_fixed > 0 ? (noise_time - horizon + fixed_from) : noise_time;
return(noise_terms);
}

// setup approximate gaussian process
matrix setup_gp(int M, real L, int dimension) {
vector[dimension] time;
Expand All @@ -44,6 +65,7 @@ matrix setup_gp(int M, real L, int dimension) {
}
return(PHI);
}

// update gaussian process using spectral densities
vector update_gp(matrix PHI, int M, real L, real alpha,
real rho, vector eta, int type) {
Expand All @@ -57,15 +79,24 @@ vector update_gp(matrix PHI, int M, real L, real alpha,
for(m in 1:M){
diagSPD[m] = sqrt(spd_se(alpha, unit_rho, sqrt(lambda(L, m))));
}
}else if (type == 1) {
} else if (type == 1) {
for(m in 1:M){
diagSPD[m] = sqrt(spd_ou(alpha, unit_rho, sqrt(lambda(L, m))));
}
} else if (type == 2) {
for(m in 1:M){
diagSPD[m] = sqrt(spd_matern32(alpha, unit_rho, sqrt(lambda(L, m))));
}
} else if (type == 3) {
for(m in 1:M){
diagSPD[m] = sqrt(spd_matern(alpha, unit_rho, sqrt(lambda(L, m))));
diagSPD[m] = sqrt(spd_matern52(alpha, unit_rho, sqrt(lambda(L, m))));
}
}
SPD_eta = diagSPD .* eta;
noise = noise + PHI[,] * SPD_eta;
return(noise);
}

// priors for gaussian process
void gaussian_process_lp(real rho, real alpha, vector eta,
real ls_meanlog, real ls_sdlog,
Expand All @@ -75,6 +106,6 @@ void gaussian_process_lp(real rho, real alpha, vector eta,
} else {
rho ~ inv_gamma(1.499007, 0.057277 * ls_max) T[ls_min, ls_max];
}
alpha ~ normal(0, alpha_sd);
alpha ~ normal(0, alpha_sd) T[0,];
eta ~ std_normal();
}
2 changes: 1 addition & 1 deletion inst/stan/functions/rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ vector update_Rt(int t, real log_R, vector noise, array[] int bps,
if (t > gp_n) {
gp[(gp_n + 1):t] = rep_vector(noise[gp_n], t - gp_n);
}
}else{
} else {
gp[2:(gp_n + 1)] = noise;
gp = cumulative_sum(gp);
}
Expand Down
21 changes: 14 additions & 7 deletions man/gp_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions tests/testthat/test-estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ test_that("estimate_infections works without setting a generation time", {
expect_equal(exp(combined$growth_rate), combined$R)
})

test_that("estimate_infections works with different kernels", {
skip_on_cran()
test_estimate_infections(reported_cases, gp = gp_opts(kernel = "se"))
test_estimate_infections(reported_cases, gp = gp_opts(kernel = "ou"))
test_estimate_infections(reported_cases, gp = gp_opts(matern_order = 5 / 2))
expect_error(gp_opts(matern_order = 4))
})

test_that("estimate_infections fails as expected when given a very short timeout", {
skip_on_cran()
expect_error(output <- capture.output(suppressMessages(
Expand Down
21 changes: 18 additions & 3 deletions vignettes/gaussian_process_implementation_details.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,34 @@ In our case as set out above, we have
k(t,t') = k(|t - t'|) = k(\Delta t)
\end{equation}

where by default $k$ is a Matern 3/2 covariance kernel,
with the following choices available for the kernel $k$

## Matérn 3/2 covariance kernel (the default)

\begin{equation}
k(\Delta t) = \alpha \left( 1 + \frac{\sqrt{3} \Delta t}{l} \right) \exp \left( - \frac{\sqrt{3} \Delta t}{l}\right)
\end{equation}

with $l>0$ and $\alpha > 0$ the length scale and magnitude, respectively, of the kernel.
Alternatively, a squared exponential kernel can be chosen to constrain the GP to be smoother.

## Squared exponential kernel

\begin{equation}
k(\Delta t) = \alpha \exp \left( - \frac{1}{2} \frac{(\Delta t^2)}{l^2} \right)
\end{equation}

## Ornstein-Uhlenbeck (Matérn 1/2) kernel

\begin{equation}
k(\Delta t) = \alpha \exp{\left( - \frac{\Delta t}{2 l^2} \right)}
\end{equation}

## Matérn 5/2 covariance kernel

\begin{equation}
k(\Delta t) = \alpha \left( 1 + \frac{\sqrt{5} \Delta t}{l} + \frac{5}{3} \left(\frac{\Delta t}{l} \right)^2 \right) \exp \left( - \frac{\sqrt{5} \Delta t}{l}\right)
\end{equation}

# Hilbert space approximation

In order to make our models computationally tractable, we approximate the Gaussian Process using a Hilbert space approximation to the Gaussian Process [@approxGP], centered around mean zero.
Expand All @@ -78,7 +93,7 @@ where $L$ is a positive number termed boundary condition, and $\beta_{j}$ are re
\beta_j \sim \mathcal{Normal}(0, 1)
\end{equation}

The function $S_k(x)$ is the spectral density relating to a particular covariance function $k$. In the case of the Matern 3/2 kernel (the default in `EpiNow2`) this is given by
The function $S_k(x)$ is the spectral density relating to a particular covariance function $k$. In the case of the Matérn 3/2 kernel (the default in `EpiNow2`) this is given by

\begin{equation}
S_k(x) = 4 \alpha^2 \left( \frac{\sqrt{3}}{\rho}\right)^3 \left(\left( \frac{\sqrt{3}}{\rho} \right)^2 + w^2 \right)^{-2}
Expand Down

0 comments on commit d803f1f

Please sign in to comment.