Skip to content

Commit

Permalink
Cmdstanr option (#537)
Browse files Browse the repository at this point in the history
* add cmdstan backend

* add cmdstanr model

* generalise extract functions

* create general fit_model function

* initial values for cmdstanr

* move stan model creation to create func

* make dist_fit cmdstanr ready

* make estimate_infections cmdstanr ready

* make estimate_secondary cmdstanr ready

* make estimate_truncation cmdstanr ready

* make simulate_infections cmdstanr ready

* gitignore for binaries

* add cmdstanr as suggest

* make forecast_secondary cmdstanr ready

* update stanargs test

* make simulations work with updated options

* add globals

* tests for cmdstanr backend

* update actions

* updates in response to lintr

* don't use future_lapply for cmdstanr

* backend-specific success criteria

* use epinowcast action for installing cmdstan

* improve .gitignore for compiled stan files

* deactivate testing on windows for now

* Revert "use epinowcast action for installing cmdstan"

This reverts commit e57ae76.

* Apply suggestions from code review

Co-authored-by: Sam Abbott <[email protected]>

* match arguments in `stan_model`

* don't match args$method but explictly stop instead

* put choices in argument

---------

Co-authored-by: Sam Abbott <[email protected]>
  • Loading branch information
sbfnk and seabbs authored Feb 14, 2024
1 parent 669be3b commit 44d360b
Show file tree
Hide file tree
Showing 42 changed files with 958 additions and 234 deletions.
15 changes: 12 additions & 3 deletions .github/workflows/R-CMD-as-cran-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,18 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
needs: check
dependencies: NA
extra-packages: |
rcmdcheck
stan-dev/cmdstanr
testthat
- name: Install cmdstan
run: |
cmdstanr::check_cmdstan_toolchain(fix = TRUE)
cmdstanr::install_cmdstan(cores = 2, quiet = TRUE)
shell: Rscript {0}

- uses: r-lib/actions/check-r-package@v2
with:
upload-snapshots: true
upload-snapshots: true
23 changes: 14 additions & 9 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,6 @@ jobs:
steps:
- uses: actions/checkout@v4

- name: Install cmdstan Linux system dependencies
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y libcurl4-openssl-dev || true
sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev || true
sudo apt-get install -y libpng-dev || true
- uses: r-lib/actions/setup-pandoc@v2

- uses: r-lib/actions/setup-r@v2
Expand All @@ -65,8 +58,20 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
needs: check
dependencies: NA
extra-packages: |
dplyr
rmarkdown
rcmdcheck
stan-dev/cmdstanr
testthat
- name: Install cmdstan
if: runner.os != 'Windows'
run: |
cmdstanr::check_cmdstan_toolchain(fix = TRUE)
cmdstanr::install_cmdstan(cores = 2, quiet = TRUE)
shell: Rscript {0}

- uses: r-lib/actions/check-r-package@v2
with:
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/lint-only-changed-files.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
dependencies: NA
extra-packages: |
stan-dev/cmdstanr
any::gh
any::lintr
any::purrr
needs: check
- name: Add lintr options
run: |
Expand All @@ -44,4 +45,4 @@ jobs:
lintr::lint_package(exclusions = exclusions_list)
shell: Rscript {0}
env:
LINTR_ERROR_ON_LINT: true
LINTR_ERROR_ON_LINT: true
4 changes: 3 additions & 1 deletion .github/workflows/synthetic-validation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
dependencies: NA
extra-packages: |
here
dplyr
tidyr
scoringutils
Expand All @@ -48,4 +50,4 @@ jobs:
with:
name: fits
retention-days: 5
path: synthetic.rds
path: synthetic.rds
15 changes: 12 additions & 3 deletions .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,18 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::covr
needs: coverage
dependencies: NA
extra-packages: |
covr
stan-dev/cmdstanr
testthat
- name: Install cmdstan
run: |
cmdstanr::check_cmdstan_toolchain(fix = TRUE)
cmdstanr::install_cmdstan(cores = 2, quiet = TRUE)
shell: Rscript {0}

- name: Test coverage
run: covr::codecov(quiet = FALSE)
shell: Rscript {0}
shell: Rscript {0}
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,8 @@ vignettes/results

# unused figures
man/figures/*.png

# exclude compiled stan files
inst/stan/*
!inst/stan/*/
!inst/stan/*.stan
6 changes: 5 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ Imports:
lubridate,
methods,
patchwork,
posterior,
progressr,
purrr,
R.utils (>= 2.0.0),
Expand All @@ -122,6 +123,7 @@ Imports:
truncnorm,
utils
Suggests:
cmdstanr,
covr,
dplyr,
here,
Expand All @@ -143,13 +145,15 @@ LinkingTo:
RcppParallel (>= 5.0.1),
rstan (>= 2.26.0),
StanHeaders (>= 2.26.0)
Additional_repositories:
https://mc-stan.org/r-packages/
Biarch: true
Config/testthat/edition: 3
Encoding: UTF-8
Language: en-GB
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
NeedsCompilation: yes
SystemRequirements: GNU make
C++17
Expand Down
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ export(estimates_by_report_date)
export(expose_stan_fns)
export(extract_CrIs)
export(extract_inits)
export(extract_samples)
export(extract_stan_param)
export(fix_dist)
export(forecast_infections)
Expand All @@ -66,6 +67,7 @@ export(make_conf)
export(map_prob_change)
export(obs_opts)
export(opts_list)
export(package_model)
export(plot_estimates)
export(plot_summary)
export(regional_epinow)
Expand All @@ -91,6 +93,8 @@ export(setup_target_folder)
export(simulate_infections)
export(simulate_secondary)
export(stan_opts)
export(stan_sampling_opts)
export(stan_vb_opts)
export(summarise_key_measures)
export(summarise_results)
export(trunc_opts)
Expand Down Expand Up @@ -133,6 +137,7 @@ importFrom(data.table,rbindlist)
importFrom(data.table,setDT)
importFrom(data.table,setDTthreads)
importFrom(data.table,setcolorder)
importFrom(data.table,setkey)
importFrom(data.table,setnames)
importFrom(data.table,setorder)
importFrom(data.table,setorderv)
Expand Down Expand Up @@ -184,6 +189,7 @@ importFrom(lifecycle,deprecate_warn)
importFrom(lubridate,days)
importFrom(lubridate,wday)
importFrom(patchwork,plot_layout)
importFrom(posterior,mcse_mean)
importFrom(progressr,progressor)
importFrom(progressr,with_progress)
importFrom(purrr,compact)
Expand Down
31 changes: 29 additions & 2 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -630,8 +630,14 @@ create_initial_conditions <- function(data) {
#'
#' @param data A list of stan data as created by [create_stan_data()]
#'
#' @param init Initial conditions passed to `{rstan}`. Defaults to "random" but
#' can also be a function (as supplied by [create_initial_conditions()]).
#' @param init Initial conditions passed to `{rstan}`. Defaults to "random"
#' (initial values randomly drawn between -2 and 2) but can also be a
#' function (as supplied by [create_initial_conditions()]).
#'
#' @param model Character, name of the model for which arguments are
#' to be created.
#' @param fixed_param Logical, defaults to `FALSE`. Should arguments be
#' created to sample from fixed parameters (used by simulation functions).
#'
#' @param verbose Logical, defaults to `FALSE`. Should verbose progress
#' messages be returned.
Expand All @@ -650,7 +656,28 @@ create_initial_conditions <- function(data) {
create_stan_args <- function(stan = stan_opts(),
data = NULL,
init = "random",
model = "estimate_infections",
fixed_param = FALSE,
verbose = FALSE) {
if (fixed_param) {
if (stan$backend == "rstan") {
stan$algorithm <- "Fixed_param"
} else if (stan$backend == "cmdstanr") {
stan$fixed_param <- TRUE
stan$adapt_delta <- NULL
stan$max_treedepth <- NULL
}
}
## generate stan model
if (is.null(stan$object)) {
stan$object <- stan_model(stan$backend, model)
stan$backend <- NULL
}
# cmdstanr doesn't have an init = "random" argument
if (is.character(init) && init == "random" &&
inherits(stan$object, "CmdStanModel")) {
init <- 2
}
# set up shared default arguments
args <- list(
data = data,
Expand Down
37 changes: 22 additions & 15 deletions R/dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ dist_skel <- function(n, dist = FALSE, cum = TRUE, model,
#' @return A stan fit of an interval censored distribution
#' @author Sam Abbott
#' @export
#' @inheritParams stan_opts
#' @examples
#' \donttest{
#' # integer adjusted exponential model
Expand All @@ -221,7 +222,8 @@ dist_skel <- function(n, dist = FALSE, cum = TRUE, model,
#' )
#' }
dist_fit <- function(values = NULL, samples = 1000, cores = 1,
chains = 2, dist = "exp", verbose = FALSE) {
chains = 2, dist = "exp", verbose = FALSE,
backend = "rstan") {
if (samples < 1000) {
samples <- 1000
warning(sprintf("%s %s", "`samples` must be at least 1000.",
Expand All @@ -244,7 +246,7 @@ dist_fit <- function(values = NULL, samples = 1000, cores = 1,
par_sigma = numeric(0)
)

model <- stanmodels$dist_fit
model <- stan_model(backend, "dist_fit")

if (dist == "exp") {
data$dist <- 0
Expand All @@ -268,16 +270,21 @@ dist_fit <- function(values = NULL, samples = 1000, cores = 1,
}

# fit model
fit <- rstan::sampling(
model,
data = data,
iter = samples + 1000,
warmup = 1000,
control = list(adapt_delta = adapt_delta),
chains = chains,
cores = cores,
refresh = ifelse(verbose, 50, 0)
args <- create_stan_args(
stan = stan_opts(
model,
samples = samples,
warmup = 1000,
control = list(adapt_delta = adapt_delta),
chains = chains,
cores = cores,
backend = backend
),
data = data, verbose = verbose, model = "dist_fit"
)

fit <- fit_model(args, id = "dist_fit")

return(fit)
}

Expand Down Expand Up @@ -533,11 +540,11 @@ bootstrapped_dist_fit <- function(values, dist = "lognormal",

out <- list()
if (dist == "lognormal") {
out$mean_samples <- sample(rstan::extract(fit)$mu, samples)
out$sd_samples <- sample(rstan::extract(fit)$sigma, samples)
out$mean_samples <- sample(extract(fit)$mu, samples)
out$sd_samples <- sample(extract(fit)$sigma, samples)
} else if (dist == "gamma") {
alpha_samples <- sample(rstan::extract(fit)$alpha, samples)
beta_samples <- sample(rstan::extract(fit)$beta, samples)
alpha_samples <- sample(extract(fit)$alpha, samples)
beta_samples <- sample(extract(fit)$beta, samples)
out$mean_samples <- alpha_samples / beta_samples
out$sd_samples <- sqrt(alpha_samples) / beta_samples
}
Expand Down
Loading

0 comments on commit 44d360b

Please sign in to comment.