Skip to content

Commit

Permalink
Almost working with "import all functions" strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
athowes committed Nov 15, 2024
1 parent 56b9a76 commit 473ddc1
Show file tree
Hide file tree
Showing 2 changed files with 856 additions and 17 deletions.
29 changes: 12 additions & 17 deletions inst/cohort-scratch.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,13 @@ summary(fit_direct_weighted)

lognormal <- brms::lognormal()

primarycensored_lognormal_uniform_lcdf <- brms::custom_family(
"primarycensored_lognormal_uniform_lcdf",
primarycensored_lognormal_uniform_lpmf <- brms::custom_family(
"primarycensored_lognormal_uniform",
dpars = lognormal$dpar,
links = c(lognormal$link, lognormal$link_sigma),
type = lognormal$type,
type = "int",
loop = TRUE,
vars = c("pwindow", "vreal1[n]")
)

primarycensored_lognormal_uniform_lcdf_file <- file.path(
tempdir(), "primarycensored_lognormal_uniform_lcdf.stan"
vars = c("vreal1[n]", "pwindow[n]")
)

data <- cohort_obs |>
Expand All @@ -68,14 +64,10 @@ data <- cohort_obs |>
pwindow = 1,
q = pmax(d - pwindow, 0)
)

pcd_function <- pcd_load_stan_functions("primarycensored_lognormal_uniform_lcdf")
pcd_function <- sub(pattern = "array\\[\\] real params", "real mu, real sigma", pcd_function)
pcd_function <- gsub("\\s*real mu = params\\[1\\];\\n\\s*real sigma = params\\[2\\];\\n", "", pcd_function)

stanvars_functions <- brms::stanvar(
block = "functions",
scode = pcd_function
scode = .stan_chunk("cohort_model/primarycensored-edit.stan")
)

# stanvars_tparameters <- brms::stanvar(
Expand All @@ -98,16 +90,19 @@ stanvars_data <- brms::stanvar(

stanvars_all <- stanvars_functions + stanvars_data

brms::make_stancode(
stancode <- brms::make_stancode(
formula = d | weights(n) + vreal(q) ~ 1,
family = primarycensored_lognormal_uniform_lcdf,
family = primarycensored_lognormal_uniform_lpmf,
data = data,
stanvars = stanvars_all,
)

model <- rstan::stan_model(model_code = stancode)

fit_pcd <- brms::brm(
formula = d | weights(n), vreal(q) ~ 1,
family = primarycensored_lognormal_uniform_lcdf,
formula = d | weights(n) + vreal(q) ~ 1,
family = primarycensored_lognormal_uniform_lpmf,
data = data,
stanvars = stanvars_all,
backend = "cmdstanr"
)
Loading

0 comments on commit 473ddc1

Please sign in to comment.