Skip to content

Commit

Permalink
Create function opt_check_tbl_horizon_timediff. Resolves #31
Browse files Browse the repository at this point in the history
  • Loading branch information
annakrystalli committed Sep 25, 2023
1 parent c0e0c44 commit 9aa46d7
Show file tree
Hide file tree
Showing 12 changed files with 331 additions and 24 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export(is_info)
export(is_success)
export(not_pass)
export(opt_check_tbl_col_timediff)
export(opt_check_tbl_horizon_timediff)
export(read_model_out_file)
export(try_check)
export(validate_model_data)
Expand Down
79 changes: 79 additions & 0 deletions R/opt_check_tbl_horizon_timediff.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#' Check time difference between values in two date columns equal a defined period.
#'
#' @param t0_colname Character string. The name of the time zero date column.
#' @param t1_colname Character string. The name of the time zero + 1 time step date column.
#' @param horizon_colname Character string. The name of the horizon column.
#' Defaults to `"horizon"`.
#' @param timediff an object of class `lubridate` [`Period-class`] and length 1.
#' The period of a single horizon. Default to 1 week.
#' @inherit check_tbl_colnames params
#' @inherit check_tbl_col_types return
#' @export
opt_check_tbl_horizon_timediff <- function(tbl, file_path, hub_path, t0_colname,
t1_colname, horizon_colname = "horizon",
timediff = lubridate::weeks()) {
checkmate::assert_class(timediff, "Period")
checkmate::assert_scalar(timediff)
checkmate::assert_character(t0_colname, len = 1L)
checkmate::assert_character(t1_colname, len = 1L)
checkmate::assert_character(horizon_colname, len = 1L)
checkmate::assert_choice(t0_colname, choices = names(tbl))
checkmate::assert_choice(t1_colname, choices = names(tbl))
checkmate::assert_choice(horizon_colname, choices = names(tbl))

config_tasks <- hubUtils::read_config(hub_path, "tasks")
schema <- hubUtils::create_hub_schema(config_tasks,
partitions = NULL,
r_schema = TRUE
)
assert_column_date(t0_colname, schema)
assert_column_date(t1_colname, schema)
assert_column_integer(horizon_colname, schema)

if (!lubridate::is.Date(tbl[[t0_colname]])) {
tbl[, t0_colname] <- as.Date(tbl[[t0_colname]])
}
if (!lubridate::is.Date(tbl[[t1_colname]])) {
tbl[, t1_colname] <- as.Date(tbl[[t1_colname]])
}
if (!is.integer(tbl[[horizon_colname]])) {
tbl[, horizon_colname] <- as.integer(tbl[[horizon_colname]])
}

compare <- tbl[[t0_colname]] + (timediff * tbl[[horizon_colname]]) == tbl[[t1_colname]]
check <- all(compare)
if (check) {
details <- NULL
} else {
invalid_vals <- paste0(
tbl[[t1_colname]][!compare],
" (horizon = ", tbl[[horizon_colname]][!compare], ")"
) %>% unique()

details <- cli::format_inline(
"t1 var value{?s} {.val {invalid_vals}} are invalid."
)
}

capture_check_cnd(
check = check,
file_path = file_path,
msg_subject = cli::format_inline(
"Time differences between t0 var {.var {t0_colname}} and t1 var
{.var {t1_colname}}"
),
msg_verbs = c("all match", "do not all match"),
msg_attribute = cli::format_inline("expected period of {.val {timediff}} * {.var {horizon_colname}}."),
details = details
)
}

assert_column_integer <- function(colname, schema) {
if (schema[colname] != "integer") {
cli::cli_abort(
"Column {.arg colname} must be configured as {.cls integer} not
{.cls {schema[colname]}}.",
call = rlang::caller_call()
)
}
}
7 changes: 4 additions & 3 deletions inst/testhubs/flusight/hub-config/validations.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
default:
validate_model_data:
col_timediff:
fn: "opt_check_tbl_col_timediff"
horizon_timediff:
fn: "opt_check_tbl_horizon_timediff"
pkg: "hubValidations"
args:
t0_colname: "forecast_date"
t1_colname: "target_end_date"
timediff: !expr lubridate::weeks(2)
horizon_colname: "horizon"
timediff: !expr lubridate::weeks()
54 changes: 54 additions & 0 deletions man/opt_check_tbl_horizon_timediff.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/_snaps/execute_custom_checks.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"testdata", "config", "validations.yml")))
Output
List of 1
$ col_timediff:List of 4
$ horizon_timediff:List of 4
..$ message : chr "Time differences between t0 var `forecast_date` and t1 var\n `target_end_date` all match expected period"| __truncated__
..$ where : chr "hub-ensemble/2023-05-08-hub-ensemble.parquet"
..$ call : NULL
Expand All @@ -20,7 +20,7 @@
"testdata", "config", "validations-error.yml")))
Output
List of 1
$ col_timediff:List of 4
$ horizon_timediff:List of 4
..$ message : chr "Time differences between t0 var `forecast_date` and t1 var\n `target_end_date` do not all match expected"| __truncated__
..$ where : chr "hub-ensemble/2023-05-08-hub-ensemble.parquet"
..$ call : NULL
Expand Down
72 changes: 72 additions & 0 deletions tests/testthat/_snaps/opt_check_tbl_horizon_timediff.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# opt_check_tbl_horizon_timediff works

Code
opt_check_tbl_horizon_timediff(tbl, file_path, hub_path, t0_colname = "forecast_date",
t1_colname = "target_end_date")
Output
<message/check_success>
Message:
Time differences between t0 var `forecast_date` and t1 var `target_end_date` all match expected period of 7d 0H 0M 0S * `horizon`.

---

Code
opt_check_tbl_horizon_timediff(tbl_chr, file_path, hub_path, t0_colname = "forecast_date",
t1_colname = "target_end_date")
Output
<message/check_success>
Message:
Time differences between t0 var `forecast_date` and t1 var `target_end_date` all match expected period of 7d 0H 0M 0S * `horizon`.

---

Code
opt_check_tbl_horizon_timediff(tbl, file_path, hub_path, t0_colname = "forecast_date",
t1_colname = "target_end_date")
Output
<warning/check_failure>
Warning:
Time differences between t0 var `forecast_date` and t1 var `target_end_date` do not all match expected period of 7d 0H 0M 0S * `horizon`. t1 var value "2023-05-22 (horizon = 1)" are invalid.

---

Code
opt_check_tbl_horizon_timediff(tbl, file_path, hub_path, t0_colname = "forecast_date",
t1_colname = "target_end_date", timediff = lubridate::weeks(2))
Output
<warning/check_failure>
Warning:
Time differences between t0 var `forecast_date` and t1 var `target_end_date` do not all match expected period of 14d 0H 0M 0S * `horizon`. t1 var values "2023-05-15 (horizon = 1)" and "2023-05-22 (horizon = 2)" are invalid.

# opt_check_tbl_horizon_timediff fails correctly

Code
opt_check_tbl_horizon_timediff(tbl, file_path, hub_path, t0_colname = "forecast_date",
t1_colname = "target_end_dates")
Error <simpleError>
Assertion on 't1_colname' failed: Must be element of set {'forecast_date','target_end_date','horizon','target','location','output_type','output_type_id','value'}, but is 'target_end_dates'.

---

Code
opt_check_tbl_horizon_timediff(tbl, file_path, hub_path, t0_colname = "forecast_date",
t1_colname = c("target_end_date", "forecast_date"))
Error <simpleError>
Assertion on 't1_colname' failed: Must have length 1, but has length 2.

---

Code
opt_check_tbl_horizon_timediff(tbl, file_path, hub_path, t0_colname = "forecast_date",
t1_colname = "target_end_date", timediff = 7L)
Error <simpleError>
Assertion on 'timediff' failed: Must inherit from class 'Period', but has class 'integer'.

---

Code
opt_check_tbl_horizon_timediff(tbl, file_path, hub_path, t0_colname = "forecast_date",
t1_colname = "target_end_date")
Error <rlang_error>
Column `colname` must be configured as <Date> not <character>.

8 changes: 4 additions & 4 deletions tests/testthat/_snaps/validate_model_data.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,22 +218,22 @@
# validate_model_data with config function works

Code
validate_model_data(hub_path, file_path)[["col_timediff"]]
validate_model_data(hub_path, file_path)[["horizon_timediff"]]
Output
<message/check_success>
Message:
Time differences between t0 var `forecast_date` and t1 var `target_end_date` all match expected period of 14d 0H 0M 0S.
Time differences between t0 var `forecast_date` and t1 var `target_end_date` all match expected period of 7d 0H 0M 0S * `horizon`.

---

Code
validate_model_data(hub_path, file_path, validations_cfg_path = system.file(
"testhubs/flusight/hub-config/validations.yml", package = "hubValidations"))[[
"col_timediff"]]
"horizon_timediff"]]
Output
<message/check_success>
Message:
Time differences between t0 var `forecast_date` and t1 var `target_end_date` all match expected period of 14d 0H 0M 0S.
Time differences between t0 var `forecast_date` and t1 var `target_end_date` all match expected period of 7d 0H 0M 0S * `horizon`.

# validate_model_data print method work [plain]

Expand Down
88 changes: 88 additions & 0 deletions tests/testthat/test-opt_check_tbl_horizon_timediff.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
test_that("opt_check_tbl_horizon_timediff works", {
hub_path <- system.file("testhubs/flusight", package = "hubValidations")
file_path <- "hub-ensemble/2023-05-08-hub-ensemble.parquet"
tbl <- hubValidations::read_model_out_file(file_path, hub_path)


expect_snapshot(
opt_check_tbl_horizon_timediff(tbl, file_path, hub_path,
t0_colname = "forecast_date",
t1_colname = "target_end_date"
)
)

tbl_chr <- hubUtils::coerce_to_character(tbl)
expect_snapshot(
opt_check_tbl_horizon_timediff(tbl_chr, file_path, hub_path,
t0_colname = "forecast_date",
t1_colname = "target_end_date"
)
)

tbl$target_end_date[1] <- tbl$forecast_date[1] + lubridate::weeks(2)
expect_snapshot(
opt_check_tbl_horizon_timediff(tbl, file_path, hub_path,
t0_colname = "forecast_date",
t1_colname = "target_end_date"
)
)

expect_snapshot(
opt_check_tbl_horizon_timediff(tbl, file_path, hub_path,
t0_colname = "forecast_date",
t1_colname = "target_end_date",
timediff = lubridate::weeks(2)
)
)
})


test_that("opt_check_tbl_horizon_timediff fails correctly", {
hub_path <- system.file("testhubs/flusight", package = "hubValidations")
file_path <- "hub-ensemble/2023-05-08-hub-ensemble.parquet"
tbl <- hubValidations::read_model_out_file(file_path, hub_path)

expect_snapshot(
opt_check_tbl_horizon_timediff(tbl, file_path, hub_path,
t0_colname = "forecast_date",
t1_colname = "target_end_dates"
),
error = TRUE
)

expect_snapshot(
opt_check_tbl_horizon_timediff(tbl, file_path, hub_path,
t0_colname = "forecast_date",
t1_colname = c("target_end_date", "forecast_date")
),
error = TRUE
)

expect_snapshot(
opt_check_tbl_horizon_timediff(tbl, file_path, hub_path,
t0_colname = "forecast_date",
t1_colname = "target_end_date",
timediff = 7L
),
error = TRUE
)

schema <- c(
forecast_date = "Date", target = "character", horizon = "integer",
location = "character", output_type = "character", output_type_id = "character",
value = "double", target_end_date = "character"
)
mockery::stub(
opt_check_tbl_horizon_timediff,
"hubUtils::create_hub_schema",
schema,
2
)
expect_snapshot(
opt_check_tbl_horizon_timediff(tbl, file_path, hub_path,
t0_colname = "forecast_date",
t1_colname = "target_end_date"
),
error = TRUE
)
})
4 changes: 2 additions & 2 deletions tests/testthat/test-validate_model_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ test_that("validate_model_data with config function works", {
hub_path <- system.file("testhubs/flusight", package = "hubValidations")
file_path <- "hub-ensemble/2023-05-08-hub-ensemble.parquet"
expect_snapshot(
validate_model_data(hub_path, file_path)[["col_timediff"]]
validate_model_data(hub_path, file_path)[["horizon_timediff"]]
)
expect_snapshot(
validate_model_data(
Expand All @@ -48,7 +48,7 @@ test_that("validate_model_data with config function works", {
"testhubs/flusight/hub-config/validations.yml",
package = "hubValidations"
)
)[["col_timediff"]]
)[["horizon_timediff"]]
)
})

Expand Down
Loading

0 comments on commit 9aa46d7

Please sign in to comment.