diff --git a/R/create.R b/R/create.R index d9d45992d..d3e8bedc6 100644 --- a/R/create.R +++ b/R/create.R @@ -273,7 +273,8 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, rt <- rt_opts( use_rt = FALSE, future = "project", - gp_on = "R0" + gp_on = "R0", + rw = 0 ) } # define future Rt arguments @@ -283,6 +284,10 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, ) # apply random walk if (rt$rw != 0) { + if (is.null(breakpoints)) { + stop("breakpoints must be supplied when using random walk") + } + breakpoints <- seq_along(breakpoints) breakpoints <- floor(breakpoints / rt$rw) if (!(rt$future == "project")) { @@ -292,11 +297,12 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, } } }else { - if (is.null(breakpoints) || sum(breakpoints) == 0) { - rt$use_breakpoints <- FALSE - } breakpoints <- cumsum(breakpoints) } + + if (sum(breakpoints) == 0) { + rt$use_breakpoints <- FALSE + } # add a shift for 0 effect in breakpoints breakpoints <- breakpoints + 1 diff --git a/tests/testthat/test-create_rt_date.R b/tests/testthat/test-create_rt_date.R new file mode 100644 index 000000000..748ae80d4 --- /dev/null +++ b/tests/testthat/test-create_rt_date.R @@ -0,0 +1,88 @@ +test_that("create_rt_data returns expected default values", { + result <- create_rt_data() + + expect_type(result, "list") + expect_equal(result$r_mean, 1) + expect_equal(result$r_sd, 1) + expect_equal(result$estimate_r, 1) + expect_equal(result$bp_n, 0) + expect_equal(result$breakpoints, numeric(0)) + expect_equal(result$future_fixed, 1) + expect_equal(result$fixed_from, 0) + expect_equal(result$pop, 0) + expect_equal(result$stationary, 0) + expect_equal(result$future_time, 0) +}) + +test_that("create_rt_data handles NULL rt input correctly", { + result <- create_rt_data(rt = NULL) + + expect_equal(result$estimate_r, 0) + expect_equal(result$future_fixed, 0) + expect_equal(result$stationary, 1) +}) + +test_that("create_rt_data handles custom rt_opts correctly", { + custom_rt <- rt_opts( + prior = list(mean = 2, sd = 0.5), + use_rt = FALSE, + rw = 0, + use_breakpoints = FALSE, + future = "project", + gp_on = "R0", + pop = 1000000 + ) + + result <- create_rt_data(rt = custom_rt, horizon = 7) + + expect_equal(result$r_mean, 2) + expect_equal(result$r_sd, 0.5) + expect_equal(result$estimate_r, 0) + expect_equal(result$pop, 1000000) + expect_equal(result$stationary, 1) + expect_equal(result$future_time, 7) +}) + +test_that("create_rt_data handles breakpoints correctly", { + result <- create_rt_data(rt_opts(use_breakpoints = TRUE), + breakpoints = c(1, 0, 1, 0, 1)) + + expect_equal(result$bp_n, 3) + expect_equal(result$breakpoints, c(2, 2, 3, 3, 4)) +}) + +test_that("create_rt_data handles random walk correctly", { + result <- create_rt_data(rt_opts(rw = 2), + breakpoints = rep(1, 10)) + + expect_equal(result$bp_n, 5) + expect_equal(result$breakpoints, c(1, 2, 2, 3, 3, 4, 4, 5, 5, 6)) +}) + +test_that("create_rt_data throws error for invalid inputs", { + expect_error(create_rt_data(rt_opts(rw = 2)), + "breakpoints must be supplied when using random walk") +}) + +test_that("create_rt_data handles future projections correctly", { + result <- create_rt_data(rt_opts(future = "project"), horizon = 7) + + expect_equal(result$future_fixed, 0) + expect_equal(result$fixed_from, 0) + expect_equal(result$future_time, 7) +}) + +test_that("create_rt_data handles zero sum breakpoints", { + result <- create_rt_data(rt_opts(use_breakpoints = TRUE), + breakpoints = rep(0, 5)) + + expect_equal(result$bp_n, 0) +}) + +test_that("create_rt_data adjusts breakpoints for horizon", { + result <- create_rt_data(rt_opts(rw = 2, future = "latest"), + breakpoints = rep(1, 10), + horizon = 3) + + expect_equal(result$breakpoints, c(1, 2, 2, 3, 3, 4, 4, 4, 4, 4)) +})