Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed rep phi #560

Merged
merged 6 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
* Fixed a bug in the bounds of delays when setting initial conditions. By @sbfnk in #474.
* Added input checking to `estimate_infections()`, `estimate_secondary()`, `estimate_truncation()`, `simulate_infections()`, and `epinow()`. `check_reports_valid()` has been added to validate the reports dataset passed to these functions. Tests are added to check `check_reports_valid()`. As part of input validation, the various `*_opts()` functions now return subclasses of the same name as the functions and are tested against passed arguments to ensure the right `*_opts()` is passed to the right argument. For example, the `obs` argument in `estimate_secondary()` is expected to only receive arguments passed through `obs_opts()` and will error otherwise. By @jamesmbaazam in #476 and reviewed by @sbfnk and @seabbs.
* Added the possibility of specifying a fixed observation scaling. By @sbfnk in #550 and reviewed by @seabbs.
* Added the possibility of specifying fixed overdispersion. By @sbfnk in .

## Model changes

Expand Down
4 changes: 2 additions & 2 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,8 @@ create_gp_data <- function(gp = gp_opts(), data) {
create_obs_model <- function(obs = obs_opts(), dates) {
data <- list(
model_type = as.numeric(obs$family == "negbin"),
phi_mean = obs$phi[1],
phi_sd = obs$phi[2],
phi_mean = obs$phi$mean,
phi_sd = obs$phi$sd,
week_effect = ifelse(obs$week_effect, obs$week_length, 1),
obs_weight = obs$weight,
obs_scale = as.integer(obs$scale$sd > 0 || obs$scale$mean != 1),
Expand Down
32 changes: 20 additions & 12 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,11 @@ gp_opts <- function(basis_prop = 0.2,
#' model. Custom settings can be supplied which override the defaults.
#' @param family Character string defining the observation model. Options are
#' Negative binomial ("negbin"), the default, and Poisson.
#' @param phi A numeric vector of length 2, defaults to 0, 1. Indicates the mean
#' and standard deviation of the normal prior used for the observation
#' process.
#' @param phi Overdispersion parameter of the reporting process, used only if
#' `familiy` is "negbin". Can be supplied either as a single numeric value
#' (fixed overdispersion) or a list with numeric elements mean (`mean`) and
#' standard deviation (`sd`) defining a normally distributed overdispersion.
#' Defaults to a list with elements `mean = 0` and `sd = 1`.
#' @param weight Numeric, defaults to 1. Weight to give the observed data in the
#' log density.
#' @param week_effect Logical defaulting to `TRUE`. Should a day of the week
Expand Down Expand Up @@ -470,17 +472,14 @@ gp_opts <- function(basis_prop = 0.2,
#' # Scale reported data
#' obs_opts(scale = list(mean = 0.2, sd = 0.02))
obs_opts <- function(family = "negbin",
phi = c(0, 1),
phi = list(mean = 0, sd = 1),
weight = 1,
week_effect = TRUE,
week_length = 7,
scale = 1,
na = c("missing", "accumulate"),
likelihood = TRUE,
return_likelihood = FALSE) {
if (length(phi) != 2 || !is.numeric(phi)) {
stop("phi be numeric and of length two")
}
na <- arg_match(na)
if (na == "accumulate") {
message(
Expand All @@ -493,6 +492,13 @@ obs_opts <- function(family = "negbin",
)
}

if (length(phi) == 2 && is.numeric(phi)) {
warning(
"Specifying `phi` as a length 2 vector is deprecated. Mean and SD ",
"should be given as list elements."
)
phi <- list(mean = phi[1], sd = phi[2])
}
obs <- list(
family = arg_match(family, values = c("poisson", "negbin")),
phi = phi,
Expand All @@ -505,11 +511,13 @@ obs_opts <- function(family = "negbin",
return_likelihood = return_likelihood
)

if (is.numeric(obs$scale)) {
obs$scale <- list(mean = obs$scale, sd = 0)
}
if (!(all(c("mean", "sd") %in% names(obs$scale)))) {
stop("If specifying a scale as list both a mean and sd are needed")
for (param in c("phi", "scale")) {
if (is.numeric(obs[[param]])) {
obs[[param]] <- list(mean = obs[[param]], sd = 0)
}
if (!(all(c("mean", "sd") %in% names(obs[[param]])))) {
stop("If specifying a ", param, " as list both a mean and sd are needed")
}
}

attr(obs, "class") <- c("obs_opts", class(obs))
Expand Down
6 changes: 4 additions & 2 deletions inst/stan/functions/observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ void report_lp(array[] int cases, array[] int cases_time, vector reports,
obs_cases = cases;
}
if (model_type) {
real dispersion = 1 / pow(rep_phi[model_type], 2);
rep_phi[model_type] ~ normal(phi_mean, phi_sd) T[0,];
real dispersion = 1 / pow(phi_sd > 0 ? rep_phi[model_type] : phi_mean, 2);
sbfnk marked this conversation as resolved.
Show resolved Hide resolved
if (phi_sd > 0) {
rep_phi[model_type] ~ normal(phi_mean, phi_sd) T[0,];
}
if (weight == 1) {
obs_cases ~ neg_binomial_2(obs_reports, dispersion);
} else {
Expand Down
10 changes: 6 additions & 4 deletions man/obs_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 16 additions & 2 deletions tests/testthat/test-create_obs_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,23 @@ test_that("create_obs_model can be used with a custom week length", {
})

test_that("create_obs_model can be used with a user set phi", {
obs <- create_obs_model(dates = dates, obs = obs_opts(phi = c(10, 0.1)))
obs <- create_obs_model(
dates = dates, obs = obs_opts(phi = list(mean = 10, sd = 0.1))
)
expect_equal(obs$phi_mean, 10)
expect_equal(obs$phi_sd, 0.1)
expect_error(obs_opts(phi = c(10)))
obs <- create_obs_model(
dates = dates,
obs = obs_opts(phi = 0.5)
)
expect_equal(obs$phi_mean, 0.5)
expect_equal(obs$phi_sd, 0)
expect_error(obs_opts(phi = c("Hi", "World")))
})

test_that("using a vector for phi in create_obs_model is deprecated", {
expect_warning(
create_obs_model(dates = dates, obs = obs_opts(phi = c(10, 0.1))),
"deprecated"
)
})
Loading