Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #267: Refactor to allow custom event priors and marginalise the latent likelihood #474

Merged
merged 28 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
0fc1fa9
first pass at refactoring latent model to use window formulas
seabbs Nov 24, 2024
4b3195c
add docs to stan function
seabbs Nov 25, 2024
190f2e1
check getting started -drive by fix plotting
seabbs Nov 25, 2024
4dd9e4f
update approach to handling formulas
seabbs Nov 25, 2024
8ff54fe
get reparameterisation from brms itself vs enforcing manual declaration
seabbs Nov 25, 2024
bf4a910
work on regexing:
seabbs Nov 25, 2024
269e8f6
test manually setting new priors
seabbs Nov 25, 2024
eb7d502
fix .replace_prior
seabbs Nov 25, 2024
4b2ae8e
reset for pause
seabbs Nov 25, 2024
8ece45f
add back in lower bounds
seabbs Nov 25, 2024
03d24f0
revert pass in via formula
seabbs Nov 26, 2024
7f4fa97
add custom priors pass in
seabbs Nov 26, 2024
4381326
write priors down more neatly
seabbs Nov 26, 2024
b0b4834
add manual prior mode and optout
seabbs Nov 26, 2024
856aacf
clean up easy test failures
seabbs Nov 26, 2024
57d2dae
use marginalised log likelihood
seabbs Nov 27, 2024
23dea17
debug marginalised likelihood
seabbs Nov 27, 2024
237f48a
workaround for liklihood vectorisation
seabbs Nov 27, 2024
cb0c403
further increase prior complexity options
seabbs Nov 27, 2024
750ea52
update prior ordering
seabbs Nov 27, 2024
b2578a6
catch printing issue for .replace_prior
seabbs Nov 27, 2024
ead5f67
add news iteem
seabbs Nov 27, 2024
fb50250
add PR links
seabbs Nov 27, 2024
5beebde
speeed up test
seabbs Nov 27, 2024
bafa1cf
code read through
seabbs Nov 27, 2024
df49ef0
clean up precommit
seabbs Nov 27, 2024
c425116
turn off priorsense to check theory its numerical instability for ext…
seabbs Nov 27, 2024
8711ddc
review comments
seabbs Nov 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ repos:
- id: check-added-large-files
args: ['--maxkb=200']
- id: end-of-file-fixer
exclude: ^tests/testthat/_snaps
- repo: local
hooks:
- id: forbid-to-commit
Expand Down
12 changes: 9 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ S3method(assert_epidist,epidist_linelist_data)
S3method(assert_epidist,epidist_naive_model)
S3method(epidist_family_model,default)
S3method(epidist_family_model,epidist_latent_model)
S3method(epidist_family_param,default)
S3method(epidist_family_prior,default)
S3method(epidist_family_prior,lognormal)
S3method(epidist_family_reparam,default)
S3method(epidist_family_reparam,gamma)
S3method(epidist_formula_model,default)
S3method(epidist_formula_model,epidist_latent_model)
S3method(epidist_model_prior,default)
S3method(epidist_model_prior,epidist_latent_model)
S3method(epidist_stancode,default)
S3method(epidist_stancode,epidist_latent_model)
export(Gamma)
Expand All @@ -33,8 +33,8 @@ export(epidist)
export(epidist_diagnostics)
export(epidist_family)
export(epidist_family_model)
export(epidist_family_param)
export(epidist_family_prior)
export(epidist_family_reparam)
export(epidist_formula)
export(epidist_formula_model)
export(epidist_gen_posterior_epred)
Expand All @@ -57,16 +57,20 @@ export(simulate_secondary)
export(simulate_uniform_cases)
export(weibull)
import(ggplot2)
importFrom(brms,as.brmsprior)
importFrom(brms,bf)
importFrom(brms,lognormal)
importFrom(brms,make_stancode)
importFrom(brms,prior)
importFrom(brms,set_prior)
importFrom(brms,stanvar)
importFrom(brms,weibull)
importFrom(checkmate,assert_class)
importFrom(checkmate,assert_data_frame)
importFrom(checkmate,assert_date)
importFrom(checkmate,assert_factor)
importFrom(checkmate,assert_integer)
importFrom(checkmate,assert_integerish)
importFrom(checkmate,assert_names)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,assert_true)
Expand All @@ -75,12 +79,14 @@ importFrom(cli,cli_alert_info)
importFrom(cli,cli_inform)
importFrom(cli,cli_warn)
importFrom(dplyr,bind_cols)
importFrom(dplyr,bind_rows)
importFrom(dplyr,filter)
importFrom(dplyr,full_join)
importFrom(dplyr,mutate)
importFrom(dplyr,select)
importFrom(lubridate,days)
importFrom(lubridate,is.timepoint)
importFrom(purrr,map_dbl)
importFrom(stats,Gamma)
importFrom(stats,as.formula)
importFrom(stats,setNames)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ Development version of `epidist`.
## Package

- Remove the default method for `epidist()`. See #473.
- Added `enforce_presence` argument to `epidist_prior()` to allow for priors to be
specified if they do not match existing parameters. See #474.
- Added a `merge` argument to `epidist_prior()` to allow for not merging user and package priors. See #474.
- Added user settable primary event priors to the latent model. See #474.
- Added a marginalised likelihood to the latent model. See #474.
- Generalised the stan reparametrisation feature to work across all distributions without manual specification by generating stan code with `brms` and then extracting the reparameterisation. See #474.

## Documentation

Expand Down
12 changes: 10 additions & 2 deletions R/epidist.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
#' reexported as part of `epidist`.
#' @param prior One or more `brmsprior` objects created by [brms::set_prior()]
#' or related functions. These priors are passed to [epidist_prior()] in the
#' `prior` argument.
#' `prior` argument. Some models have default priors that are automatically
#' added (see [epidist_model_prior()]). These can be merged with user-provided
#' priors using the `merge_priors` argument.
#' @param merge_priors If `TRUE` then merge user priors with default priors, if
#' `FALSE` only use user priors. Defaults to `TRUE`. This may be useful if
#' the built in approaches for merging priors are not flexible enough for a
#' particular use case.
#' @param fn The internal function to be called. By default this is
#' [brms::brm()] which performs inference for the specified model. Other options
#' are [brms::make_stancode()] which returns the Stan code for the specified
Expand All @@ -25,14 +31,16 @@
#' @export
epidist <- function(data, formula = mu ~ 1,
family = lognormal(), prior = NULL,
merge_priors = TRUE,
fn = brms::brm, ...) {
assert_epidist(data)
epidist_family <- epidist_family(data, family)
epidist_formula <- epidist_formula(
data = data, family = epidist_family, formula = formula
)
epidist_prior <- epidist_prior(
data = data, family = epidist_family, formula = epidist_formula, prior
data = data, family = epidist_family, formula = epidist_formula, prior,
merge = merge_priors
)
epidist_stancode <- epidist_stancode(
data = data, family = epidist_family, formula = epidist_formula
Expand Down
76 changes: 60 additions & 16 deletions R/family.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ epidist_family <- function(data, family = lognormal(), ...) {
family <- .add_dpar_info(family)
custom_family <- epidist_family_model(data, family, ...)
class(custom_family) <- c(family$family, class(custom_family))
custom_family <- epidist_family_reparam(custom_family)
custom_family <- epidist_family_param(custom_family)
return(custom_family)
}

Expand Down Expand Up @@ -43,29 +43,73 @@ epidist_family_model.default <- function(data, family, ...) {
#' Reparameterise an `epidist` family to align `brms` and Stan
#'
#' @inheritParams epidist_family
#' @rdname epidist_family_reparam
#' @rdname epidist_family_param
#' @family family
#' @export
epidist_family_reparam <- function(family, ...) {
UseMethod("epidist_family_reparam")
epidist_family_param <- function(family, ...) {
seabbs marked this conversation as resolved.
Show resolved Hide resolved
UseMethod("epidist_family_param")
}

#' Default method for families which do not require a reparameterisation
#'
#' @inheritParams epidist_family
#' @family family
#' @export
epidist_family_reparam.default <- function(family, ...) {
family$reparam <- family$dpars
return(family)
}

#' Reparameterisation for the gamma family
#' This function extracts the Stan parameterisation for a given brms family by
#' creating a dummy model and parsing its Stan code. It looks for the log
seabbs marked this conversation as resolved.
Show resolved Hide resolved
#' probability density function (lpdf) call in the Stan code and extracts the
#' parameter order used by Stan. This is needed because brms and Stan may use
#' different parameter orderings for the same distribution.
#'
#' @param family A brms family object containing at minimum a `family` element
#' specifying the distribution family name
#' @param ... Additional arguments passed to methods (not used)
#'
#' @details
#' The function works by:
#' 1. Creating a minimal dummy model using the specified family
#' 2. Extracting the Stan code for this model
#' 3. Finding the lpdf function call for the family
#' 4. Parsing out the parameter ordering used in Stan
#' 5. Adding this as the `param` element to the family object
#'
#' @return The input family object with an additional `param` element containing
#' the Stan parameter ordering as a string
#'
#' @inheritParams epidist_family
#' @family family
#' @importFrom brms make_stancode
#' @importFrom cli cli_abort
#' @export
epidist_family_reparam.gamma <- function(family, ...) {
family$reparam <- c("shape", "shape ./ mu") # nolint
epidist_family_param.default <- function(family, ...) {
df <- data.frame(y = c(1, 2))
dummy_mdl <- make_stancode(y ~ 1, data = df, family = class(family)[1])

# get the lowered family name
family_name <- tolower(class(family)[1])

# Extract the Stan parameterisation from the dummy model code
lpdf_pattern <- paste0(
"target \\+= ", family_name, "_(lpdf|lpmf)\\(Y \\| ([^)]+)\\)" # nolint
)
lpdf_match <- regexpr(lpdf_pattern, dummy_mdl)
seabbs marked this conversation as resolved.
Show resolved Hide resolved
reparam <- if (lpdf_match > 0) {
matches <- unlist(regmatches(dummy_mdl, lpdf_match))
mu_matches <- matches[grepl("mu", matches, fixed = TRUE)]
if (length(mu_matches) > 1) {
cli_abort("Multiple Stan parameterisations found with 'mu' parameter.")
} else if (length(mu_matches) == 0) {
cli_abort("No Stan parameterisation found with 'mu' parameter.")
}
match_str <- mu_matches[1]
param <- sub(
paste0(
"target \\+= ", family_name, "_(lpdf|lpmf)\\(Y \\| " # nolint
), "",
match_str
)
param <- sub(")", "", param, fixed = TRUE)
family$param <- param
} else {
cli_abort(
"Unable to extract Stan parameterisation for {family_name}."
)
}
return(family)
}
74 changes: 72 additions & 2 deletions R/gen.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,75 @@
#' Create a function to calculate the marginalised log likelihood for double
#' censored and truncated delay distributions
#'
#' This function creates a log likelihood function that calculates the marginal
#' likelihood for a single observation by integrating over the latent primary
#' and secondary event windows. The integration is performed numerically using
#' [primarycensored::dpcens()] which handles the double censoring and truncation
#' of the delay distribution.
#'
#' The marginal likelihood accounts for uncertainty in both the primary and
#' secondary event windows by integrating over their possible values, weighted
#' by their respective uniform distributions.
#'
#' @seealso [brms::log_lik()] for details on the brms log likelihood interface.
#'
#' @inheritParams epidist_family
#'
#' @return A function that calculates the marginal log likelihood for a single
#' observation. The prep object must have the following variables:
#' * `vreal1`: relative observation time
#' * `vreal2`: primary event window
#' * `vreal3`: secondary event window
#'
#' @family gen
#' @autoglobal
#' @importFrom purrr map_dbl
epidist_gen_log_lik <- function(family) {
# Get internal brms log_lik function
log_lik_brms <- .get_brms_fn("log_lik", family)

.log_lik <- function(i, prep) {
y <- prep$data$Y[i]
relative_obs_time <- prep$data$vreal1[i]
pwindow <- prep$data$vreal2[i]
swindow <- prep$data$vreal3[i]

# make the prep object censored
# -1 here is equivalent to right censored in brms
prep$data$cens <- -1
seabbs marked this conversation as resolved.
Show resolved Hide resolved

# Calculate density for each draw using primarycensored::dpcens()
lpdf <- purrr::map_dbl(seq_len(prep$ndraws), function(draw) {
# Define pdist function that filters to current draw
pdist_draw <- function(q, i, prep, ...) {
purrr::map_dbl(q, function(x) {
prep$data$Y <- rep(x, length(prep$data$Y))
ll <- exp(log_lik_brms(i, prep)[draw])
return(ll)
})
}

primarycensored::dpcens(
x = y,
pdist = pdist_draw,
i = i,
prep = prep,
pwindow = pwindow,
swindow = swindow,
D = relative_obs_time,
dprimary = stats::dunif,
log = TRUE
)
})
lpdf <- brms:::log_lik_weight(lpdf, i = i, prep = prep) # nolint
seabbs marked this conversation as resolved.
Show resolved Hide resolved
return(lpdf)
}

return(.log_lik)
}

#' Create a function to draw from the posterior predictive distribution for a
#' latent model
#' double censored and truncated delay distribution
#'
#' This function creates a function that draws from the posterior predictive
#' distribution for a latent model using [primarycensored::rpcens()] to handle
Expand Down Expand Up @@ -49,7 +119,7 @@ epidist_gen_posterior_predict <- function(family) {
}

#' Create a function to draw from the expected value of the posterior predictive
#' distribution for a latent model
#' distribution for a model
#'
#' This function creates a function that calculates the expected value of the
#' posterior predictive distribution for a latent model. The returned function
Expand Down
1 change: 1 addition & 0 deletions R/globals.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ utils::globalVariables(c(
"samples", # <epidist_diagnostics>
"woverlap", # <epidist_stancode.epidist_latent_model>
"rlnorm", # <simulate_secondary>
"fix", # <.replace_prior>
"prior_new", # <.replace_prior>
"source_new", # <.replace_prior>
NULL
Expand Down
Loading
Loading