Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

179: make D real #180

Merged
merged 7 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Development release.

- Added a missing `@family` tag to the `pcens` functions. This omission resulted in the Weibull analytical solution not being visible in the package documentation.
- Changed a call to `size()` to use `num_elements()` instead as an underlying type conversion was causing issues on some platforms.
- Changed `D` to be of type real in `pcens_model.stan` in order to support infinite `relative_obs_time`.

# primarycensored 1.0.0

Expand Down
4 changes: 2 additions & 2 deletions inst/stan/pcens_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ functions {

real partial_sum(array[] int dummy, int start, int end,
array[] int d, array[] int d_upper, array[] int n,
array[] int pwindow, array[] int D,
array[] int pwindow, data array[] real D,
int dist_id, array[] real params,
int primary_id, array[] real primary_params) {
real partial_target = 0;
Expand All @@ -27,7 +27,7 @@ data {
array[N] int<lower=0> d_upper; // observed delays upper bound
array[N] int<lower=0> n; // number of occurrences for each delay
array[N] int<lower=0> pwindow; // primary censoring window
array[N] int<lower=0> D; // maximum delay
array[N] real<lower=0> D; // maximum delay
int<lower=1, upper=17> dist_id; // distribution identifier
int<lower=1, upper=2> primary_id; // primary distribution identifier
int<lower=0> n_params; // number of distribution parameters
Expand Down
69 changes: 69 additions & 0 deletions tests/testthat/test-pcd_cmdstan_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,72 @@ test_that(
)
}
)

test_that("pcd_cmdstan_model recovers true values with no bound on D", {
# Simulate data
set.seed(123)
n <- 2000
true_meanlog <- 1.5
true_sdlog <- 0.5

simulated_delays <- rprimarycensored(
n = n,
rdist = rlnorm,
meanlog = true_meanlog,
sdlog = true_sdlog,
pwindow = 1,
D = Inf
)

simulated_data <- data.frame(
delay = simulated_delays,
delay_upper = simulated_delays + 1,
pwindow = 1,
relative_obs_time = Inf
)

delay_counts <- simulated_data |>
dplyr::summarise(
n = dplyr::n(),
.by = c(pwindow, relative_obs_time, delay, delay_upper)
)

# Prepare data for Stan
stan_data <- pcd_as_stan_data(
delay_counts,
dist_id = 1, # Lognormal
primary_id = 1, # Uniform
param_bounds = list(lower = c(-Inf, 0), upper = c(Inf, Inf)),
primary_param_bounds = list(lower = numeric(0), upper = numeric(0)),
priors = list(location = c(0, 1), scale = c(1, 1)),
primary_priors = list(location = numeric(0), scale = numeric(0))
)

# Fit model
model <- suppressMessages(suppressWarnings(pcd_cmdstan_model()))
fit <- suppressMessages(suppressWarnings(model$sample(
data = stan_data,
seed = 123,
chains = 4,
seabbs marked this conversation as resolved.
Show resolved Hide resolved
parallel_chains = 4,
seabbs marked this conversation as resolved.
Show resolved Hide resolved
refresh = 0,
show_messages = FALSE,
iter_warmup = 500
)))

# Extract posterior
posterior <- fit$draws(c("params[1]", "params[2]"), format = "df")

# Check mean estimates
expect_equal(mean(posterior$`params[1]`), true_meanlog, tolerance = 0.05)
expect_equal(mean(posterior$`params[2]`), true_sdlog, tolerance = 0.05)

# Check credible intervals
ci_meanlog <- quantile(posterior$`params[1]`, c(0.05, 0.95))
ci_sdlog <- quantile(posterior$`params[2]`, c(0.05, 0.95))

expect_gt(true_meanlog, ci_meanlog[1])
expect_lt(true_meanlog, ci_meanlog[2])
expect_gt(true_sdlog, ci_sdlog[1])
expect_lt(true_sdlog, ci_sdlog[2])
})
Loading