Skip to content

Commit

Permalink
rename truncate -> truncate_obs
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Aug 8, 2024
1 parent 8453729 commit 6e5dff3
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 18 deletions.
2 changes: 1 addition & 1 deletion inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ transformed parameters {
);
}
profile("truncate") {
obs_reports = truncate(reports[1:ot], trunc_rev_cmf, 0);
obs_reports = truncate_obs(reports[1:ot], trunc_rev_cmf, 0);
}
} else {
obs_reports = reports[1:ot];
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/estimate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ transformed parameters {
delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist,
0, 1, 1
);
secondary = truncate(secondary, trunc_rev_cmf, 0);
secondary = truncate_obs(secondary, trunc_rev_cmf, 0);
}
}

Expand Down
6 changes: 3 additions & 3 deletions inst/stan/estimate_truncation.stan
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ transformed parameters{
vector[t] last_obs;
// reconstruct latest data without truncation

last_obs = truncate(to_vector(obs[, obs_sets]), trunc_rev_cmf, 1);
last_obs = truncate_obs(to_vector(obs[, obs_sets]), trunc_rev_cmf, 1);
// apply truncation to latest dataset to map back to previous data sets and
// add noise term
for (i in 1:(obs_sets - 1)) {
trunc_obs[1:(end_t[i] - start_t[i] + 1), i] =
truncate(last_obs[start_t[i]:end_t[i]], trunc_rev_cmf, 0) + sigma;
truncate_obs(last_obs[start_t[i]:end_t[i]], trunc_rev_cmf, 0) + sigma;
}
}
}
Expand Down Expand Up @@ -80,7 +80,7 @@ generated quantities {
matrix[delay_type_max[trunc_id] + 1, obs_sets - 1] gen_obs;
// reconstruct all truncated datasets using posterior of the truncation distribution
for (i in 1:obs_sets) {
recon_obs[1:(end_t[i] - start_t[i] + 1), i] = truncate(
recon_obs[1:(end_t[i] - start_t[i] + 1), i] = truncate_obs(
to_vector(obs[start_t[i]:end_t[i], i]), trunc_rev_cmf, 1
);
}
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/functions/observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ vector scale_obs(vector reports, real frac_obs) {
*
* @return A vector of truncated reports.
*/
vector truncate(vector reports, vector trunc_rev_cmf, int reconstruct) {
vector truncate_obs(vector reports, vector trunc_rev_cmf, int reconstruct) {
int t = num_elements(reports);
int trunc_max = num_elements(trunc_rev_cmf);
vector[t] trunc_reports = reports;
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ generated quantities {
delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist,
0, 1, 1
);
reports[i] = to_row_vector(truncate(
reports[i] = to_row_vector(truncate_obs(
to_vector(reports[i]), trunc_rev_cmf, 0)
);
}
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/simulate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ generated quantities {
delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist,
0, 1, 1
);
secondary = truncate(
secondary = truncate_obs(
secondary, trunc_rev_cmf, 0
);
}
Expand Down
2 changes: 0 additions & 2 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ if (identical(Sys.getenv("NOT_CRAN"), "true")) {
target_dir = system.file("stan/functions", package = "EpiNow2")
)
)
# avoid problems due to base::truncate
stan_truncate <- truncate
}
}

Expand Down
16 changes: 8 additions & 8 deletions tests/testthat/test-stan-truncate.R
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
skip_on_cran()
skip_on_os("windows")

test_that("truncate() can perform truncation as expected", {
test_that("truncate_obs() can perform truncation as expected", {
reports <- c(10, 20, 30, 40, 50)
trunc_rev_cmf <- c(1, 0.8, 0.5, 0.2)
expected <- c(reports[1], reports[2:5] * trunc_rev_cmf)
expect_equal(stan_truncate(reports, trunc_rev_cmf, FALSE), expected)
expect_equal(truncate_obs(reports, trunc_rev_cmf, FALSE), expected)
})

test_that("truncate() can perform reconstruction as expected", {
test_that("truncate_obs() can perform reconstruction as expected", {
reports <- c(10, 20, 15, 8, 10)
trunc_rev_cmf <- c(1, 0.8, 0.5, 0.2)
expected <- c(reports[1], reports[2:5] / trunc_rev_cmf)
expect_equal(stan_truncate(reports, trunc_rev_cmf, TRUE), expected)
expect_equal(truncate_obs(reports, trunc_rev_cmf, TRUE), expected)
})

test_that("truncate() can handle longer trunc_rev_cmf than reports", {
test_that("truncate_obs() can handle longer trunc_rev_cmf than reports", {
reports <- c(10, 20, 30)
trunc_rev_cmf <- c(1, 0.8, 0.5, 0.2, 0.1)
expected <- reports * trunc_rev_cmf[3:5]
expect_equal(stan_truncate(reports, trunc_rev_cmf, FALSE), expected)
expect_equal(truncate_obs(reports, trunc_rev_cmf, FALSE), expected)
})

test_that("truncate() can handle reconstruction with longer trunc_rev_cmf than reports", {
test_that("truncate_obs() can handle reconstruction with longer trunc_rev_cmf than reports", {
reports <- c(10, 16, 15)
trunc_rev_cmf <- c(1, 0.8, 0.5, 0.2, 0.1)
expected <- reports / trunc_rev_cmf[3:5]
expect_equal(stan_truncate(reports, trunc_rev_cmf, TRUE), expected)
expect_equal(truncate_obs(reports, trunc_rev_cmf, TRUE), expected)
})

0 comments on commit 6e5dff3

Please sign in to comment.