Skip to content

Commit

Permalink
use capture.output in more places
Browse files Browse the repository at this point in the history
  • Loading branch information
jgabry committed Nov 22, 2024
1 parent 4b80205 commit 3af034b
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 65 deletions.
2 changes: 2 additions & 0 deletions tests/testthat/test-fit-init.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ set_cmdstan_path()
data_list_schools <- testing_data("schools")
data_list_logistic <- testing_data("logistic")
test_inits <- function(mod, fit_init, data_list = NULL) {
utils::capture.output({
fit_sample <- mod$sample(data = data_list, chains = 1, init = fit_init,
iter_sampling = 100, iter_warmup = 100, refresh = 0, seed = 1234)
fit_sample_multi <- mod$sample(data = data_list, chains = 5, init = fit_init,
Expand All @@ -20,6 +21,7 @@ test_inits <- function(mod, fit_init, data_list = NULL) {
draws = posterior::as_draws_rvars(fit_init$draws())
fit_sample_draws <- mod$sample(data = data_list, chains = 1, init = draws,
iter_sampling = 100, iter_warmup = 100, refresh = 0, seed = 1234)
})
return(0)
}

Expand Down
34 changes: 21 additions & 13 deletions tests/testthat/test-fit-shared.R
Original file line number Diff line number Diff line change
Expand Up @@ -454,22 +454,26 @@ test_that("draws are returned for model with spaces", {
test_that("sampling with inits works with include_paths", {
stan_program_w_include <- testing_stan_file("bernoulli_include")
exe <- cmdstan_ext(strip_ext(stan_program_w_include))
if(file.exists(exe)) {
if (file.exists(exe)) {
file.remove(exe)
}

mod_w_include <- cmdstan_model(stan_file = stan_program_w_include, quiet = FALSE,
include_paths = test_path("resources", "stan"))
mod_w_include <- cmdstan_model(stan_file = stan_program_w_include,
include_paths = test_path("resources", "stan"))

data_list <- list(N = 10, y = c(0,1,0,0,0,0,0,0,0,1))

fit <- mod_w_include$sample(
data = data_list,
seed = 123,
chains = 4,
parallel_chains = 4,
refresh = 500,
init = list(list(theta = 0.25), list(theta = 0.25), list(theta = 0.25), list(theta = 0.25))
expect_no_error(
fit <- mod_w_include$sample(
data = data_list,
seed = 123,
chains = 4,
parallel_chains = 4,
refresh = 500,
init = list(list(theta = 0.25),
list(theta = 0.25),
list(theta = 0.25),
list(theta = 0.25))
)
)
})

Expand Down Expand Up @@ -548,8 +552,12 @@ test_that("code() warns if model not created with Stan file", {
stan_program <- testing_stan_file("bernoulli")
mod <- testing_model("bernoulli")
mod_exe <- cmdstan_model(exe_file = mod$exe_file())
fit_exe <- mod_exe$sample(data = list(N = 10, y = c(0, 1, 0, 1, 0, 1, 0, 1, 0, 1)),
refresh = 0)
utils::capture.output(
fit_exe <- mod_exe$sample(
data = list(N = 10, y = c(0, 1, 0, 1, 0, 1, 0, 1, 0, 1)),
refresh = 0
)
)
expect_warning(
expect_null(fit_exe$code()),
"'$code()' will return NULL because the 'CmdStanModel' was not created with a Stan file",
Expand Down
18 changes: 12 additions & 6 deletions tests/testthat/test-model-expose-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,13 @@ stan_prog <- paste(function_decl,
model <- write_stan_file(stan_prog)
data_list <- testing_data("bernoulli")
mod <- cmdstan_model(model, force_recompile = TRUE)
fit <- mod$sample(data = data_list)
utils::capture.output(
fit <- mod$sample(data = data_list)
)


test_that("Functions can be exposed in model object", {
expect_no_error(mod$expose_functions(verbose = TRUE))
expect_no_error(mod$expose_functions())
})


Expand Down Expand Up @@ -260,7 +262,7 @@ test_that("Functions handle complex types correctly", {
})

test_that("Functions can be exposed in fit object", {
fit$expose_functions(verbose = TRUE)
fit$expose_functions()

expect_equal(
fit$functions$rtn_vec(c(1,2,3,4)),
Expand All @@ -284,7 +286,9 @@ test_that("Compiled functions can be copied to global environment", {

test_that("Functions can be compiled with model", {
mod <- cmdstan_model(model, force_recompile = TRUE, compile_standalone = TRUE)
fit <- mod$sample(data = data_list)
utils::capture.output(
fit <- mod$sample(data = data_list)
)

expect_message(
fit$expose_functions(),
Expand Down Expand Up @@ -344,9 +348,11 @@ test_that("rng functions can be exposed", {
model <- write_stan_file(stan_prog)
data_list <- testing_data("bernoulli")
mod <- cmdstan_model(model, force_recompile = TRUE)
fit <- mod$sample(data = data_list)
utils::capture.output(
fit <- mod$sample(data = data_list)
)

fit$expose_functions(verbose = TRUE)
fit$expose_functions()
set.seed(10)
res1_1 <- fit$functions$wrap_normal_rng(5,10)
res2_1 <- fit$functions$wrap_normal_rng(5,10)
Expand Down
24 changes: 16 additions & 8 deletions tests/testthat/test-model-init.R
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,12 @@ test_that("Initial values for single-element containers treated correctly", {
"
mod <- cmdstan_model(write_stan_file(modcode), force_recompile = TRUE)
expect_no_error(
fit <- mod$sample(
data = list(y_mean = 0),
init = list(list(y = c(0))),
chains = 1
utils::capture.output(
fit <- mod$sample(
data = list(y_mean = 0),
init = list(list(y = c(0))),
chains = 1
)
)
)
})
Expand All @@ -331,7 +333,13 @@ test_that("Pathfinder inits do not drop dimensions", {
"
mod <- cmdstan_model(write_stan_file(modcode), force_recompile = TRUE)
data <- list(N = 100, y = rnorm(100))
pf <- mod$pathfinder(data = data, psis_resample = FALSE)
expect_no_error(fit <- mod$sample(data = data, init = pf, chains = 1,
iter_warmup = 100, iter_sampling = 100))
})
utils::capture.output(
pf <- mod$pathfinder(data = data, psis_resample = FALSE)
)
expect_no_error(
utils::capture.output(
fit <- mod$sample(data = data, init = pf, chains = 1,
iter_warmup = 100, iter_sampling = 100)
)
)
})
25 changes: 17 additions & 8 deletions tests/testthat/test-model-laplace.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,12 @@ test_that("laplace() runs when all arguments specified validly", {
})

test_that("laplace() all valid 'mode' inputs give same results", {
mode <- mod$optimize(data = data_list, jacobian = TRUE, seed = 100, refresh = 0)
fit1 <- mod$laplace(data = data_list, mode = mode, seed = 100, refresh = 0)
fit2 <- mod$laplace(data = data_list, mode = mode$output_files(), seed = 100, refresh = 0)
fit3 <- mod$laplace(data = data_list, mode = NULL, seed = 100, refresh = 0)
utils::capture.output({
mode <- mod$optimize(data = data_list, jacobian = TRUE, seed = 100, refresh = 0)
fit1 <- mod$laplace(data = data_list, mode = mode, seed = 100, refresh = 0)
fit2 <- mod$laplace(data = data_list, mode = mode$output_files(), seed = 100, refresh = 0)
fit3 <- mod$laplace(data = data_list, mode = NULL, seed = 100, refresh = 0)
})

expect_is(fit1, "CmdStanLaplace")
expect_is(fit2, "CmdStanLaplace")
Expand All @@ -85,17 +87,22 @@ test_that("laplace() all valid 'mode' inputs give same results", {
})

test_that("laplace() allows choosing number of draws", {
fit <- mod$laplace(data = data_list, draws = 10, refresh = 0)
utils::capture.output({
fit <- mod$laplace(data = data_list, draws = 10, refresh = 0)
fit2 <- mod$laplace(data = data_list, draws = 100, refresh = 0)
})

expect_equal(fit$metadata()$draws, 10)
expect_equal(posterior::ndraws(fit$draws()), 10)

fit2 <- mod$laplace(data = data_list, draws = 100, refresh = 0)
expect_equal(fit2$metadata()$draws, 100)
expect_equal(posterior::ndraws(fit2$draws()), 100)
})

test_that("laplace() errors if jacobian arg doesn't match what optimize used", {
fit <- mod$optimize(data = data_list, jacobian = FALSE, refresh = 0)
utils::capture.output(
fit <- mod$optimize(data = data_list, jacobian = FALSE, refresh = 0)
)
expect_error(
mod$laplace(data = data_list, mode = fit, jacobian = TRUE),
"'jacobian' argument to optimize and laplace must match"
Expand All @@ -107,7 +114,9 @@ test_that("laplace() errors if jacobian arg doesn't match what optimize used", {
})

test_that("laplace() errors with bad combinations of arguments", {
fit <- mod$optimize(data = data_list, jacobian = TRUE, refresh = 0)
utils::capture.output(
fit <- mod$optimize(data = data_list, jacobian = TRUE, refresh = 0)
)
expect_error(
mod$laplace(data = data_list, mode = mod, opt_args = list(iter = 10)),
"Cannot specify both 'opt_args' and 'mode' arguments."
Expand Down
57 changes: 39 additions & 18 deletions tests/testthat/test-model-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ skip_if(os_is_wsl())
set_cmdstan_path()
mod <- cmdstan_model(testing_stan_file("bernoulli_log_lik"), force_recompile = TRUE)
data_list <- testing_data("bernoulli")
fit <- mod$sample(data = data_list, chains = 1, refresh = 0)
utils::capture.output(
fit <- mod$sample(data = data_list, chains = 1, refresh = 0)
)

test_that("Model methods automatically initialise when needed", {
expect_no_error(fit$log_prob(unconstrained_variables=c(0.1)))
Expand Down Expand Up @@ -59,7 +61,9 @@ test_that("Model methods environments are independent", {
data_list_2 <- data_list
data_list_2$N <- 20
data_list_2$y <- c(data_list$y, data_list$y)
fit_2 <- mod$sample(data = data_list_2, chains = 1)
utils::capture.output(
fit_2 <- mod$sample(data = data_list_2, chains = 1)
)
fit_2$init_model_methods()

expect_equal(fit$log_prob(unconstrained_variables=c(0.1)), -8.6327599208828509347)
Expand Down Expand Up @@ -90,8 +94,10 @@ test_that("methods error for incorrect inputs", {

logistic_mod <- cmdstan_model(testing_stan_file("logistic"), force_recompile = TRUE)
logistic_data_list <- testing_data("logistic")
logistic_fit <- logistic_mod$sample(data = logistic_data_list, chains = 1)
logistic_fit$init_model_methods(verbose = TRUE)
utils::capture.output(
logistic_fit <- logistic_mod$sample(data = logistic_data_list, chains = 1)
)
logistic_fit$init_model_methods()

expect_error(
logistic_fit$unconstrain_variables(list(alpha = 0.5)),
Expand All @@ -104,7 +110,9 @@ test_that("Methods error with already-compiled model", {
precompile_mod <- testing_model("bernoulli")
mod <- testing_model("bernoulli")
data_list <- testing_data("bernoulli")
fit <- mod$sample(data = data_list, chains = 1)
utils::capture.output(
fit <- mod$sample(data = data_list, chains = 1)
)
expect_error(
fit$init_model_methods(),
"Model methods cannot be used with a pre-compiled Stan executable, the model must be compiled again",
Expand All @@ -116,7 +124,9 @@ test_that("Methods can be compiled with model", {
mod <- cmdstan_model(testing_stan_file("bernoulli"),
force_recompile = TRUE,
compile_model_methods = TRUE)
fit <- mod$sample(data = data_list, chains = 1)
utils::capture.output(
fit <- mod$sample(data = data_list, chains = 1)
)

lp <- fit$log_prob(unconstrained_variables=c(0.6))
expect_equal(lp, -10.649855405830624733)
Expand Down Expand Up @@ -156,7 +166,9 @@ test_that("unconstrain_variables correctly handles zero-length containers", {
mod <- cmdstan_model(write_stan_file(model_code),
force_recompile = TRUE,
compile_model_methods = TRUE)
fit <- mod$sample(data = list(N = 0), chains = 1)
utils::capture.output(
fit <- mod$sample(data = list(N = 0), chains = 1)
)
unconstrained <- fit$unconstrain_variables(variables = list(x = 5))
expect_equal(unconstrained, 5)
})
Expand All @@ -179,21 +191,23 @@ test_that("unconstrain_draws returns correct values", {
mod <- cmdstan_model(write_stan_file(model_code),
compile_model_methods = TRUE,
force_recompile = TRUE)
fit <- mod$sample(data = list(N = 0), chains = 2, save_warmup = TRUE)
fit_no_warmup <- mod$sample(data = list(N = 0), chains = 2)
utils::capture.output({
fit <- mod$sample(data = list(N = 0), chains = 2, save_warmup = TRUE)
fit_no_warmup <- mod$sample(data = list(N = 0), chains = 2)
})

x_draws <- fit$draws(format = "draws_df")$x
x_draws_warmup <- fit$draws(format = "draws_df", inc_warmup = TRUE)$x

# Unconstrain all internal draws
unconstrained_internal_draws <- fit$unconstrain_draws()
unconstrained_internal_draws_warmup <- fit$unconstrain_draws(inc_warmup = TRUE)
expect_equal(as.numeric(x_draws), as.numeric(unconstrained_internal_draws))
expect_equal(as.numeric(x_draws_warmup), as.numeric(unconstrained_internal_draws_warmup))

expect_error({unconstrained_internal_draws <- fit_no_warmup$unconstrain_draws(inc_warmup = TRUE)},
"Warmup draws were requested from a fit object without them! Please rerun the model with save_warmup = TRUE.")

# Unconstrain external CmdStan CSV files
unconstrained_csv <- fit$unconstrain_draws(files = fit$output_files())
unconstrained_csv_warmup <- fit$unconstrain_draws(files = fit$output_files(),
Expand All @@ -204,7 +218,7 @@ test_that("unconstrain_draws returns correct values", {
# Unconstrain existing draws object
unconstrained_draws <- fit$unconstrain_draws(draws = fit$draws())
expect_equal(as.numeric(x_draws), as.numeric(unconstrained_draws))

expect_message(fit$unconstrain_draws(draws = fit$draws(), inc_warmup = TRUE),
"'inc_warmup' cannot be used with a draws object. Ignoring.")

Expand All @@ -224,7 +238,9 @@ test_that("unconstrain_draws returns correct values", {
mod <- cmdstan_model(write_stan_file(model_code),
compile_model_methods = TRUE,
force_recompile = TRUE)
fit <- mod$sample(data = list(N = 0), chains = 2)
utils::capture.output(
fit <- mod$sample(data = list(N = 0), chains = 2)
)

x_draws <- fit$draws(format = "draws_df")$x

Expand All @@ -241,10 +257,13 @@ test_that("unconstrain_draws returns correct values", {
})

test_that("Model methods can be initialised for models with no data", {

stan_file <- write_stan_file("parameters { real x; } model { x ~ std_normal(); }")
mod <- cmdstan_model(stan_file, compile_model_methods = TRUE, force_recompile = TRUE)
expect_no_error(fit <- mod$sample())
expect_no_error(
utils::capture.output(
fit <- mod$sample()
)
)
expect_equal(fit$log_prob(5), -12.5)
})

Expand All @@ -268,8 +287,10 @@ test_that("Variable skeleton returns correct dimensions for matrices", {
force_recompile = TRUE)
N <- 4
K <- 3
fit <- mod$sample(data = list(N = N, K = K), chains = 1,
iter_warmup = 1, iter_sampling = 5)
utils::capture.output(
fit <- mod$sample(data = list(N = N, K = K), chains = 1,
iter_warmup = 1, iter_sampling = 5)
)

target_skeleton <- list(
x_real = array(0, dim = 1),
Expand Down
7 changes: 4 additions & 3 deletions tests/testthat/test-model-optimize.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,10 @@ test_that("optimize() method runs when the stan file is removed", {
})

test_that("optimize() recognizes new jacobian argument", {
fit <- mod$optimize(data = data_list, jacobian = FALSE)
utils::capture.output({
fit <- mod$optimize(data = data_list, jacobian = FALSE)
fit2 <- mod$optimize(data = data_list, jacobian = TRUE)
})
expect_equal(fit$metadata()$jacobian, 0)

fit2 <- mod$optimize(data = data_list, jacobian = TRUE)
expect_equal(fit2$metadata()$jacobian, 1)
})
18 changes: 10 additions & 8 deletions tests/testthat/test-model-sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -350,15 +350,17 @@ test_that("Errors are suppressed with show_exceptions", {
"
errmod <- cmdstan_model(write_stan_file(errmodcode), force_recompile = TRUE)

expect_message(
suppressWarnings(errmod$sample(data = list(y_mean = 1), chains = 1)),
"Chain 1 Exception: vector[uni] assign: accessing element out of range",
fixed = TRUE
)
expect_sample_output(
expect_message(
suppressWarnings(errmod$sample(data = list(y_mean = 1), chains = 1)),
"Chain 1 Exception: vector[uni] assign: accessing element out of range",
fixed = TRUE
))

expect_no_message(
suppressWarnings(errmod$sample(data = list(y_mean = 1), chains = 1, show_exceptions = FALSE))
)
expect_sample_output(
expect_no_message(
suppressWarnings(errmod$sample(data = list(y_mean = 1), chains = 1, show_exceptions = FALSE))
))
})

test_that("All output can be suppressed by show_messages", {
Expand Down
Loading

0 comments on commit 3af034b

Please sign in to comment.