diff --git a/R/marginal_model.R b/R/marginal_model.R index b54d4c581..a546c5eb4 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -59,3 +59,66 @@ assert_epidist.epidist_marginal_model <- function(data, ...) { is_epidist_marginal_model <- function(data) { inherits(data, "epidist_marginal_model") } + +#' Create the model-specific component of an `epidist` custom family +#' +#' @inheritParams epidist_family_model +#' @param ... Additional arguments passed to method. +#' @method epidist_family_model epidist_marginal_model +#' @family marginal_model +#' @export +epidist_family_model.epidist_marginal_model <- function( + data, family, ...) { + custom_family <- brms::custom_family( + "primarycensored_wrapper", + dpars = family$dpars, + links = c(family$link, family$other_links), + lb = c(NA, as.numeric(lapply(family$other_bounds, "[[", "lb"))), + ub = c(NA, as.numeric(lapply(family$other_bounds, "[[", "ub"))), + type = "int", + loop = TRUE, + vars = "vreal1[n]" + ) + return(custom_family) +} + +#' Define the model-specific component of an `epidist` custom formula +#' +#' @inheritParams epidist_formula_model +#' @param ... Additional arguments passed to method. +#' @method epidist_formula_model epidist_marginal_model +#' @family marginal_model +#' @export +epidist_formula_model.epidist_marginal_model <- function( + data, formula, ...) { + # data is only used to dispatch on + formula <- stats::update( + formula, delay | weights(n) + vreal(pwindow) ~ . + ) + return(formula) +} + +#' @method epidist_stancode epidist_marginal_model +#' @importFrom brms stanvar +#' @family marginal_model +#' @autoglobal +#' @export +epidist_stancode.epidist_marginal_model <- function(data, ...) { + assert_epidist(data) + + stanvars_version <- .version_stanvar() + + stanvars_functions <- brms::stanvar( + block = "functions", + scode = .stan_chunk(file.path("marginal_model", "functions.stan")) + ) + + pcd_stanvars_functions <- brms::stanvar( + block = "functions", + scode = pcd_load_stan_functions() + ) + + stanvars_all <- stanvars_version + stanvars_functions + pcd_stanvars_functions + + return(stanvars_all) +} diff --git a/tests/testthat/test-marginal_model.R b/tests/testthat/test-marginal_model.R index 42f8ef731..c02820cc2 100644 --- a/tests/testthat/test-marginal_model.R +++ b/tests/testthat/test-marginal_model.R @@ -1,5 +1,45 @@ test_that("as_epidist_marginal_model.epidist_linelist_data with default settings an object with the correct classes", { # nolint: line_length_linter. prep_obs <- as_epidist_marginal_model(sim_obs) expect_s3_class(prep_obs, "data.frame") - expect_s3_class(prep_obs, "epidist_latent_model") + expect_s3_class(prep_obs, "epidist_marginal_model") +}) + +test_that("as_epidist_marginal_model.epidist_linelist_data when passed incorrect inputs", { # nolint: line_length_linter. + expect_error(as_epidist_marginal_model(list())) + expect_error(as_epidist_marginal_model(sim_obs[, 1])) +}) + +# Make this data available for other tests +family_lognormal <- epidist_family(prep_obs, family = brms::lognormal()) + +test_that("is_epidist_marginal_model returns TRUE for correct input", { # nolint: line_length_linter. + expect_true(is_epidist_marginal_model(prep_obs)) + expect_true({ + x <- list() + class(x) <- "epidist_marginal_model" + is_epidist_marginal_model(x) + }) +}) + +test_that("is_epidist_marginal_model returns FALSE for incorrect input", { # nolint: line_length_linter. + expect_false(is_epidist_marginal_model(list())) + expect_false({ + x <- list() + class(x) <- "epidist_marginal_model_extension" + is_epidist_marginal_model(x) + }) +}) + +test_that("assert_epidist.epidist_marginal_model doesn't produce an error for correct input", { # nolint: line_length_linter. + expect_no_error(assert_epidist(prep_obs)) +}) + +test_that("assert_epidist.epidist_marginal_model returns FALSE for incorrect input", { # nolint: line_length_linter. + expect_error(assert_epidist(list())) + expect_error(assert_epidist(prep_obs[, 1])) + expect_error({ + x <- list() + class(x) <- "epidist_marginal_model" + assert_epidist(x) + }) })