Skip to content

Commit

Permalink
Enable fixed observation scaling (#550)
Browse files Browse the repository at this point in the history
* allow fixed fraction observed

* add test

* update existing tests

* make indentation consistent

* add news item

* improve documentation of scale parameter

* add PR number

* Update NEWS.md

---------

Co-authored-by: Sam Abbott <[email protected]>
  • Loading branch information
sbfnk and seabbs authored Feb 19, 2024
1 parent 23229a8 commit 8a38ebb
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 36 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* Changed all instances of arguments that refer to the maximum of a distribution to reflect the maximum. Previously this did, in some instance, refer to the length of the PMF. By @sbfnk in #468.
* Fixed a bug in the bounds of delays when setting initial conditions. By @sbfnk in #474.
* Added input checking to `estimate_infections()`, `estimate_secondary()`, `estimate_truncation()`, `simulate_infections()`, and `epinow()`. `check_reports_valid()` has been added to validate the reports dataset passed to these functions. Tests are added to check `check_reports_valid()`. As part of input validation, the various `*_opts()` functions now return subclasses of the same name as the functions and are tested against passed arguments to ensure the right `*_opts()` is passed to the right argument. For example, the `obs` argument in `estimate_secondary()` is expected to only receive arguments passed through `obs_opts()` and will error otherwise. By @jamesmbaazam in #476 and reviewed by @sbfnk and @seabbs.
* Added the possibility of specifying a fixed observation scaling. By @sbfnk in #550 and reviewed by @seabbs.

## Model changes

Expand Down
14 changes: 4 additions & 10 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -415,22 +415,16 @@ create_obs_model <- function(obs = obs_opts(), dates) {
phi_sd = obs$phi[2],
week_effect = ifelse(obs$week_effect, obs$week_length, 1),
obs_weight = obs$weight,
obs_scale = as.numeric(length(obs$scale) != 0),
obs_scale = as.integer(obs$scale$sd > 0 || obs$scale$mean != 1),
obs_scale_mean = obs$scale$mean,
obs_scale_sd = obs$scale$sd,
accumulate = obs$accumulate,
likelihood = as.numeric(obs$likelihood),
return_likelihood = as.numeric(obs$return_likelihood)
)

data$day_of_week <- add_day_of_week(dates, data$week_effect)

data <- c(data, list(
obs_scale_mean = ifelse(data$obs_scale,
obs$scale$mean, 0
),
obs_scale_sd = ifelse(data$obs_scale,
obs$scale$sd, 0
)
))
return(data)
}
#' Create Stan Data Required for estimate_infections
Expand Down Expand Up @@ -614,7 +608,7 @@ create_initial_conditions <- function(data) {
out$bp_sd <- array(numeric(0))
out$bp_effects <- array(numeric(0))
}
if (data$obs_scale == 1) {
if (data$obs_scale_sd > 0) {
out$frac_obs <- array(truncnorm::rtruncnorm(1,
a = 0, b = 1,
mean = data$obs_scale_mean,
Expand Down
2 changes: 1 addition & 1 deletion R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
value.V1 := NULL
]
}
if (data$obs_scale == 1) {
if (data$obs_scale_sd > 0) {
out$fraction_observed <- extract_static_parameter("frac_obs", samples)
out$fraction_observed <- out$fraction_observed[, value := value.V1][,
value.V1 := NULL
Expand Down
27 changes: 14 additions & 13 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -436,12 +436,13 @@ gp_opts <- function(basis_prop = 0.2,
#' @param week_effect Logical defaulting to `TRUE`. Should a day of the week
#' effect be used in the observation model.
#' @param week_length Numeric assumed length of the week in days, defaulting to
#' 7 days. This can be modified if data aggregated over a period other than a
#' week or if data has a non-weekly periodicity.
#' @param scale List, defaulting to an empty list. Should an scaling factor be
#' applied to map latent infections (convolved to date of report). If none
#' empty a mean (`mean`) and standard deviation (`sd`) needs to be supplied
#' defining the normally distributed scaling factor.
#' 7 days. This can be modified if data aggregated over a period other than a
#' week or if data has a non-weekly periodicity.
#' @param scale Scaling factor to be applied to map latent infections (convolved
#' to date of report). Can be supplied either as a single numeric value (fixed
#' scale) or a list with numeric elements mean (`mean`) and standard deviation
#' (`sd`) defining a normally distributed scaling factor. Defaults to 1, i.e.
#' no scaling.
#' @param na Character. Options are "missing" (the default) and "accumulate".
#' This determines how NA values in the data are interpreted. If set to
#' "missing", any NA values in the observation data set will be interpreted as
Expand Down Expand Up @@ -473,7 +474,7 @@ obs_opts <- function(family = "negbin",
weight = 1,
week_effect = TRUE,
week_length = 7,
scale = list(),
scale = 1,
na = c("missing", "accumulate"),
likelihood = TRUE,
return_likelihood = FALSE) {
Expand Down Expand Up @@ -504,13 +505,13 @@ obs_opts <- function(family = "negbin",
return_likelihood = return_likelihood
)

if (length(obs$scale) != 0) {
scale_names <- names(obs$scale)
scale_correct <- "mean" %in% scale_names & "sd" %in% scale_names
if (!scale_correct) {
stop("If specifying a scale both a mean and sd are needed")
}
if (is.numeric(obs$scale)) {
obs$scale <- list(mean = obs$scale, sd = 0)
}
if (!(all(c("mean", "sd") %in% names(obs$scale)))) {
stop("If specifying a scale as list both a mean and sd are needed")
}

attr(obs, "class") <- c("obs_opts", class(obs))
return(obs)
}
Expand Down
6 changes: 3 additions & 3 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ parameters{
array[delay_n_p] real delay_mean; // mean of delays
array[delay_n_p] real<lower = 0> delay_sd; // sd of delays
simplex[week_effect] day_of_week_simplex;// day of week reporting effect
array[obs_scale] real<lower = 0, upper = 1> frac_obs; // fraction of cases that are ultimately observed
array[obs_scale_sd > 0 ? 1 : 0] real<lower = 0, upper = 1> frac_obs; // fraction of cases that are ultimately observed
array[model_type] real<lower = 0> rep_phi; // overdispersion of the reporting process
}

Expand Down Expand Up @@ -105,7 +105,7 @@ transformed parameters {
}
// scaling of reported cases by fraction observed
if (obs_scale) {
reports = scale_obs(reports, frac_obs[1]);
reports = scale_obs(reports, obs_scale_sd > 0 ? frac_obs[1] : obs_scale_mean);
}
// truncate near time cases to observed reports
if (trunc_id) {
Expand Down Expand Up @@ -142,7 +142,7 @@ model {
);
}
// prior observation scaling
if (obs_scale) {
if (obs_scale_sd > 0) {
frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1];
}
// observed reports from mean of reports (update likelihood)
Expand Down
11 changes: 6 additions & 5 deletions man/obs_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 12 additions & 4 deletions tests/testthat/test-create_obs_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@ test_that("create_obs_model works with default settings", {
expect_equal(length(obs), 12)
expect_equal(names(obs), c(
"model_type", "phi_mean", "phi_sd", "week_effect", "obs_weight",
"obs_scale", "accumulate", "likelihood", "return_likelihood",
"day_of_week", "obs_scale_mean",
"obs_scale_sd"
"obs_scale", "obs_scale_mean", "obs_scale_sd", "accumulate",
"likelihood", "return_likelihood", "day_of_week"
))
expect_equal(obs$model_type, 1)
expect_equal(obs$week_effect, 7)
expect_equal(obs$obs_scale, 0)
expect_equal(obs$likelihood, 1)
expect_equal(obs$return_likelihood, 0)
expect_equal(obs$day_of_week, c(7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7))
expect_equal(obs$obs_scale_mean, 0)
expect_equal(obs$obs_scale_mean, 1)
expect_equal(obs$obs_scale_sd, 0)
})

Expand All @@ -34,6 +33,15 @@ test_that("create_obs_model can be used with a scaling", {
expect_equal(obs$obs_scale_sd, 0.01)
})

test_that("create_obs_model can be used with fixed scaling", {
obs <- create_obs_model(
dates = dates,
obs = obs_opts(scale = 0.4)
)
expect_equal(obs$obs_scale_mean, 0.4)
expect_equal(obs$obs_scale_sd, 0)
})

test_that("create_obs_model can be used with no week effect", {
obs <- create_obs_model(dates = dates, obs = obs_opts(week_effect = FALSE))
expect_equal(obs$week_effect, 1)
Expand Down

0 comments on commit 8a38ebb

Please sign in to comment.