Skip to content

Commit

Permalink
Merge branch 'main' into vectorise-spectral-density
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs authored Aug 28, 2024
2 parents 13f7f8a + 903cc8b commit 7ff2590
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 65 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
- The default stan control options have been updated from `list(adapt_delta = 0.95, max_treedepth = 15)` to `list(adapt_delta = 0.9, max_treedepth = 12)` due to improved performance and to reduce the runtime of the default parameterisations. By @seabbs in #742 and reviewed by @jamesmbaazam.
- Initialisation has been simplified by sampling directly from the priors, where possible, rather than from a constrained space. By @seabbs in #742 and reviewed by @jamesmbaazam.
- Unnecessary normalisation of delay priors has been removed. By @seabbs in #742 and reviewed by @jamesmbaazam.
- Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in # and reviewed by @.
- Optimised convolution code to take into account the relative length of the vectors being convolved. See #745 by @seabbs and reviewed by @jamesmbaazam.
- Switch to broadcasting the day of the week effect. By @seabbs in #746 and reviewed by @jamesmbaazam.
- A warning is now thrown if nonparametric PMFs passed to delay options have consecutive tail values that are below a certain low threshold as these lead to loss in speed with little gain in accuracy. By @jamesmbaazam in #752 and reviewed by @seabbs.

Expand Down
114 changes: 89 additions & 25 deletions inst/stan/functions/convolve.stan
Original file line number Diff line number Diff line change
@@ -1,36 +1,100 @@
// convolve two vectors as a backwards dot product
// y vector should be reversed
// limited to the length of x and backwards looking for x indexes
/**
* Calculate convolution indices for the case where s <= xlen
*
* @param s Current position in the output vector
* @param xlen Length of the x vector
* @param ylen Length of the y vector
* @return An array of integers: {start_x, end_x, start_y, end_y}
*/
array[] int calc_conv_indices_xlen(int s, int xlen, int ylen) {
int s_minus_ylen = s - ylen;
int start_x = max(1, s_minus_ylen + 1);
int end_x = s;
int start_y = max(1, 1 - s_minus_ylen);
int end_y = ylen;
return {start_x, end_x, start_y, end_y};
}

/**
* Calculate convolution indices for the case where s > xlen
*
* @param s Current position in the output vector
* @param xlen Length of the x vector
* @param ylen Length of the y vector
* @return An array of integers: {start_x, end_x, start_y, end_y}
*/
array[] int calc_conv_indices_len(int s, int xlen, int ylen) {
int s_minus_ylen = s - ylen;
int start_x = max(1, s_minus_ylen + 1);
int end_x = xlen;
int start_y = max(1, 1 - s_minus_ylen);;
int end_y = ylen + xlen - s;
return {start_x, end_x, start_y, end_y};
}

/**
* Convolve a vector with a reversed probability mass function.
*
* This function performs a discrete convolution of two vectors, where the second vector
* is assumed to be an already reversed probability mass function.
*
* @param x The input vector to be convolved.
* @param y The already reversed probability mass function vector.
* @param len The desired length of the output vector.
* @return A vector of length `len` containing the convolution result.
* @throws If `len` is not of equal length to the sum of the lengths of `x` and `y`.
*/
vector convolve_with_rev_pmf(vector x, vector y, int len) {
int xlen = num_elements(x);
int ylen = num_elements(y);
vector[len] z;
if (xlen + ylen <= len) {
reject("convolve_with_rev_pmf: len is longer then x and y combined");
}
for (s in 1:len) {
z[s] = dot_product(
x[max(1, (s - ylen + 1)):min(s, xlen)],
y[max(1, ylen - s + 1):min(ylen, ylen + xlen - s)]
);
int xlen = num_elements(x);
int ylen = num_elements(y);
vector[len] z;

if (xlen + ylen - 1 < len) {
reject("convolve_with_rev_pmf: len is longer than x and y convolved");
}

if (xlen > len) {
reject("convolve_with_rev_pmf: len is shorter than x");
}

for (s in 1:xlen) {
array[4] int indices = calc_conv_indices_xlen(s, xlen, ylen);
z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]);
}

if (len > xlen) {
for (s in (xlen + 1):len) {
array[4] int indices = calc_conv_indices_len(s, xlen, ylen);
z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]);
}
return(z);
}

return z;
}


// convolve latent infections to reported (but still unobserved) cases
/**
* Convolve infections to reported cases.
*
* This function convolves a vector of infections with a reversed delay
* distribution to produce a vector of reported cases.
*
* @param infections A vector of infection counts.
* @param delay_rev_pmf A vector representing the reversed probability mass
* function of the delay distribution.
* @param seeding_time The number of initial time steps to exclude from the
* output.
* @return A vector of reported cases, starting from `seeding_time + 1`.
*/
vector convolve_to_report(vector infections,
vector delay_rev_pmf,
int seeding_time) {
int t = num_elements(infections);
vector[t - seeding_time] reports;
vector[t] unobs_reports = infections;
int delays = num_elements(delay_rev_pmf);
if (delays) {
unobs_reports = convolve_with_rev_pmf(unobs_reports, delay_rev_pmf, t);
reports = unobs_reports[(seeding_time + 1):t];
} else {
reports = infections[(seeding_time + 1):t];

if (delays == 0) {
return infections[(seeding_time + 1):t];
}
return(reports);

vector[t] unobs_reports = convolve_with_rev_pmf(infections, delay_rev_pmf, t);
return unobs_reports[(seeding_time + 1):t];
}
6 changes: 3 additions & 3 deletions inst/stan/functions/delays.stan
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ vector get_delay_rev_pmf(
pmf[1:new_len] = new_variable_pmf;
} else { // subsequent delay to be convolved
pmf[1:new_len] = convolve_with_rev_pmf(
pmf[1:current_len], reverse_mf(new_variable_pmf), new_len
pmf[1:current_len], reverse(new_variable_pmf), new_len
);
}
} else { // nonparametric
Expand All @@ -54,7 +54,7 @@ vector get_delay_rev_pmf(
pmf[1:new_len] = delay_np_pmf[start:end];
} else { // subsequent delay to be convolved
pmf[1:new_len] = convolve_with_rev_pmf(
pmf[1:current_len], reverse_mf(delay_np_pmf[start:end]), new_len
pmf[1:current_len], reverse(delay_np_pmf[start:end]), new_len
);
}
}
Expand All @@ -70,7 +70,7 @@ vector get_delay_rev_pmf(
pmf = cumulative_sum(pmf);
}
if (reverse_pmf) {
pmf = reverse_mf(pmf);
pmf = reverse(pmf);
}
return pmf;
}
Expand Down
33 changes: 0 additions & 33 deletions inst/stan/functions/pmfs.stan
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,3 @@ vector discretised_pmf(vector params, int n, int dist) {
}
return(exp(lpmf));
}

// reverse a mf
vector reverse_mf(vector pmf) {
int pmf_length = num_elements(pmf);
vector[pmf_length] rev_pmf;
for (d in 1:pmf_length) {
rev_pmf[d] = pmf[pmf_length - d + 1];
}
return rev_pmf;
}

vector rev_seq(int base, int len) {
vector[len] seq;
for (i in 1:len) {
seq[i] = len + base - i;
}
return(seq);
}

real rev_pmf_mean(vector rev_pmf, int base) {
int len = num_elements(rev_pmf);
vector[len] rev_pmf_seq = rev_seq(base, len);
return(dot_product(rev_pmf_seq, rev_pmf));
}

real rev_pmf_var(vector rev_pmf, int base, real mean) {
int len = num_elements(rev_pmf);
vector[len] rev_pmf_seq = rev_seq(base, len);
for (i in 1:len) {
rev_pmf_seq[i] = pow(rev_pmf_seq[i], 2);
}
return(dot_product(rev_pmf_seq, rev_pmf) - pow(mean, 2));
}
29 changes: 26 additions & 3 deletions tests/testthat/test-stan-convole.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
skip_on_cran()
skip_on_os("windows")

test_that("convolve can combine two pmfs as expected", {
# Test calc_conv_indices_xlen function
test_that("calc_conv_indices_xlen calculates correct indices", {
expect_equal(calc_conv_indices_xlen(1, 5, 3), c(1, 1, 3, 3))
expect_equal(calc_conv_indices_xlen(3, 5, 3), c(1, 3, 1, 3))
expect_equal(calc_conv_indices_xlen(5, 5, 3), c(3, 5, 1, 3))
})

# Test calc_conv_indices_len function
test_that("calc_conv_indices_len calculates correct indices", {
expect_equal(calc_conv_indices_len(6, 5, 3), c(4, 5, 1, 2))
expect_equal(calc_conv_indices_len(7, 5, 3), c(5, 5, 1, 1))
expect_equal(calc_conv_indices_len(8, 5, 3), c(6, 5, 1, 0))
})

test_that("convolve_with_rev_pmf can combine two pmfs as expected", {
expect_equal(
convolve_with_rev_pmf(c(0.1, 0.2, 0.7), rev(c(0.1, 0.2, 0.7)), 5),
c(0.01, 0.04, 0.18, 0.28, 0.49),
Expand All @@ -14,7 +28,7 @@ test_that("convolve can combine two pmfs as expected", {
)
})

test_that("convolve performs the same as a numerical convolution", {
test_that("convolve_with_rev_pmf performs the same as a numerical convolution", {
# Sample and analytical PMFs for two Poisson distributions
x <- rpois(100000, 3)
xpmf <- dpois(0:20, 3)
Expand All @@ -32,7 +46,7 @@ test_that("convolve performs the same as a numerical convolution", {
expect_lte(sum(abs(conv_cdf - cdf)), 0.1)
})

test_that("convolve_dot_product can combine vectors as we expect", {
test_that("convolve_with_rev_pmf can combine vectors as we expect", {
expect_equal(
convolve_with_rev_pmf(c(0.1, 0.2, 0.7), rev(c(0.1, 0.2, 0.7)), 3),
c(0.01, 0.04, 0.18),
Expand All @@ -54,3 +68,12 @@ test_that("convolve_dot_product can combine vectors as we expect", {
x
)
})

test_that("convolve_dot_product can combine two vectors where x > y and len = x", {
x <- c(1, 2, 3, 4, 5)
y <- c(1, 2, 3)
expect_equal(
convolve_with_rev_pmf(x, rev(y), 5),
c(1, 4, 10, 16, 22)
)
})
2 changes: 1 addition & 1 deletion tests/testthat/test-stan-secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ skip_on_os("windows")
# test primary reports and observations
reports <- rep(10, 20)
obs <- rep(4, 20)
delay_rev_pmf <- reverse_mf(discretised_pmf(c(log(3), 0.1), 5, 0))
delay_rev_pmf <- rev(discretised_pmf(c(log(3), 0.1), 5, 0))
scaled <- reports * 0.1
convolved <- rep(1e-5, 20) + convolve_to_report(scaled, delay_rev_pmf, 0)

Expand Down

0 comments on commit 7ff2590

Please sign in to comment.