Skip to content

Commit

Permalink
Merge pull request #108 from hubverse-org/93/add-derived-tid-arg
Browse files Browse the repository at this point in the history
93/add derived_task_id argument
  • Loading branch information
annakrystalli authored Aug 15, 2024
2 parents 8815afc + 1811ede commit 2cce180
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 41 deletions.
29 changes: 29 additions & 0 deletions R/config_tasks-utils.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
get_round_config <- function(config_tasks, round_id) {
round_idx <- hubUtils::get_round_idx(config_tasks, round_id)
purrr::pluck(
config_tasks,
"rounds",
round_idx
)
}

get_round_output_types <- function(config_tasks, round_id) {
round_config <- get_round_config(config_tasks, round_id)
purrr::map(
round_config[["model_tasks"]],
~ .x[["output_type"]]
)
}

get_round_output_type_names <- function(config_tasks, round_id,
collapse = TRUE) {
out <- get_round_output_types(config_tasks, round_id) %>%
purrr::map(names)

if (collapse) {
purrr::flatten_chr(out) %>%
unique()
} else {
out
}
}
135 changes: 99 additions & 36 deletions R/expand_model_out_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
#' in the round. Can be used to override the compound task ID set defined in the
#' config. If `NULL` is provided for a given modeling task, a compound task ID set of
#' all task IDs is used.
#' @param output_types character vector of output type names to include.
#' @param output_types Character vector of output type names to include.
#' Use to subset for grids for specific output types.
#' @param derived_task_ids Character vector of derived task ID names (task IDs whose
#' values depend on other task IDs) to ignore. Columns for such task ids will
#' contain `NA`s.
#'
#' @return If `bind_model_tasks = TRUE` (default) a tibble or arrow table
#' containing all possible task ID and related output type ID
Expand Down Expand Up @@ -90,9 +93,9 @@
#' include_sample_ids = TRUE
#' )
#' # Hub with sample output type and compound task ID structure
#' config_tasks <- hubUtils::read_config_file(system.file("config", "tasks-comp-tid.json",
#' package = "hubValidations"
#' ))
#' config_tasks <- hubUtils::read_config_file(
#' system.file("config", "tasks-comp-tid.json", package = "hubValidations")
#' )
#' expand_model_out_grid(config_tasks,
#' round_id = "2022-12-26",
#' include_sample_ids = TRUE
Expand All @@ -116,6 +119,30 @@
#' NULL
#' )
#' )
#' # Subset output types
#' config_tasks <- hubUtils::read_config(
#' system.file("testhubs", "samples", package = "hubValidations")
#' )
#' expand_model_out_grid(config_tasks,
#' round_id = "2022-10-29",
#' include_sample_ids = TRUE,
#' bind_model_tasks = FALSE,
#' output_types = c("sample", "pmf"),
#' )
#' expand_model_out_grid(config_tasks,
#' round_id = "2022-10-29",
#' include_sample_ids = TRUE,
#' bind_model_tasks = TRUE,
#' output_types = "sample",
#' )
#' # Ignore derived task IDs
#' expand_model_out_grid(config_tasks,
#' round_id = "2022-10-29",
#' include_sample_ids = TRUE,
#' bind_model_tasks = FALSE,
#' output_types = "sample",
#' derived_task_ids = "target_end_date"
#' )
expand_model_out_grid <- function(config_tasks,
round_id,
required_vals_only = FALSE,
Expand All @@ -129,15 +156,21 @@ expand_model_out_grid <- function(config_tasks,
bind_model_tasks = TRUE,
include_sample_ids = FALSE,
compound_taskid_set = NULL,
output_types = NULL) {
output_types = NULL,
derived_task_ids = NULL) {
checkmate::assert_list(compound_taskid_set, null.ok = TRUE)
output_type_id_datatype <- rlang::arg_match(output_type_id_datatype)
output_types <- validate_output_types(output_types, config_tasks, round_id)
derived_task_ids <- validate_derived_task_ids(
derived_task_ids,
config_tasks, round_id
)
round_config <- get_round_config(config_tasks, round_id)

task_id_l <- purrr::map(
round_config[["model_tasks"]],
~ .x[["task_ids"]] %>%
derived_taskids_to_na(derived_task_ids) %>%
null_taskids_to_na()
) %>%
# Fix round_id value to current round_id in round_id variable column
Expand Down Expand Up @@ -345,6 +378,22 @@ null_taskids_to_na <- function(model_task) {
)
}

# Set derived task_ids to all NULL values.
derived_taskids_to_na <- function(model_task, derived_task_ids) {
if (!is.null(derived_task_ids)) {
purrr::modify_at(
model_task,
.at = derived_task_ids,
.f = ~ list(
required = NULL,
optional = NA
)
)
} else {
model_task
}
}

# Adds example sample ids to the output type id column which are unique
# across multiple modeling task groups. Only apply to v3 and above sample output
# type configurations.
Expand Down Expand Up @@ -479,37 +528,6 @@ extract_mt_output_type_ids <- function(x, config_tid) {
)
}


get_round_config <- function(config_tasks, round_id) {
round_idx <- hubUtils::get_round_idx(config_tasks, round_id)
purrr::pluck(
config_tasks,
"rounds",
round_idx
)
}

get_round_output_types <- function(config_tasks, round_id) {
round_config <- get_round_config(config_tasks, round_id)
purrr::map(
round_config[["model_tasks"]],
~ .x[["output_type"]]
)
}

get_round_output_type_names <- function(config_tasks, round_id,
collapse = TRUE) {
out <- get_round_output_types(config_tasks, round_id) %>%
purrr::map(names)

if (collapse) {
purrr::flatten_chr(out) %>%
unique()
} else {
out
}
}

validate_output_types <- function(output_types, config_tasks, round_id,
call = rlang::caller_call()) {
checkmate::assert_character(output_types, null.ok = TRUE)
Expand All @@ -529,3 +547,48 @@ validate_output_types <- function(output_types, config_tasks, round_id,
}
valid_output_types
}

validate_derived_task_ids <- function(derived_task_ids, config_tasks, round_id) {
checkmate::assert_character(derived_task_ids, null.ok = TRUE)
if (is.null(derived_task_ids)) {
return(NULL)
}
round_task_ids <- hubUtils::get_round_task_id_names(config_tasks, round_id)
valid_task_ids <- intersect(derived_task_ids, round_task_ids)
if (length(valid_task_ids) < length(derived_task_ids)) {
cli::cli_warn(
c(
"x" = "{.val {setdiff(derived_task_ids, round_task_ids)}}
{?is/are} not valid task ID{?s}. Ignored.",
"i" = "{.arg derived_task_ids} must be a member of: {.val {round_task_ids}}"
),
call = rlang::caller_call()
)
}
model_tasks <- hubUtils::get_round_model_tasks(config_tasks, round_id)
has_required <- purrr::map(
model_tasks,
~ .x[["task_ids"]][valid_task_ids] %>%
purrr::map_lgl(
~ !is.null(.x$required)
)
) %>% purrr::reduce(`|`)

Check warning on line 575 in R/expand_model_out_grid.R

View workflow job for this annotation

GitHub Actions / lint

file=R/expand_model_out_grid.R,line=575,col=5,[pipe_continuation_linter] `%>%` should always have a space before it and a new line after it, unless the full pipeline fits on one line.
if (any(has_required)) {
cli::cli_abort(
c(
"x" = "Derived task IDs cannot have required task ID values.",
"!" = "{.val {names(has_required)[has_required]}} ha{?s/ve}
required task ID values. Ignored."
),
call = rlang::caller_call()
)
}
valid_task_ids <- intersect(
valid_task_ids,
names(has_required)[!has_required]
)
if (length(valid_task_ids) == 0L) {
return(NULL)
}
valid_task_ids
}
39 changes: 34 additions & 5 deletions man/expand_model_out_grid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 52 additions & 0 deletions tests/testthat/_snaps/expand_model_out_grid.md
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,58 @@
x "random" is not valid output type.
i `output_types` must be members of: "sample", "mean", and "pmf"

# expand_model_out_grid derived_task_ids ignoring works

Code
expand_model_out_grid(config_tasks, round_id = "2022-10-22",
include_sample_ids = FALSE, bind_model_tasks = TRUE, output_types = "sample",
derived_task_ids = "target_end_date")
Output
# A tibble: 80 x 8
reference_date target horizon location variant target_end_date output_type
<date> <chr> <int> <chr> <chr> <date> <chr>
1 2022-10-22 wk inc f~ 0 US AA NA sample
2 2022-10-22 wk inc f~ 1 US AA NA sample
3 2022-10-22 wk inc f~ 2 US AA NA sample
4 2022-10-22 wk inc f~ 3 US AA NA sample
5 2022-10-22 wk inc f~ 0 01 AA NA sample
6 2022-10-22 wk inc f~ 1 01 AA NA sample
7 2022-10-22 wk inc f~ 2 01 AA NA sample
8 2022-10-22 wk inc f~ 3 01 AA NA sample
9 2022-10-22 wk inc f~ 0 02 AA NA sample
10 2022-10-22 wk inc f~ 1 02 AA NA sample
# i 70 more rows
# i 1 more variable: output_type_id <chr>

---

Code
expand_model_out_grid(config_tasks, round_id = "2022-10-22",
include_sample_ids = TRUE, bind_model_tasks = TRUE, output_types = "sample",
derived_task_ids = "target_end_date", required_vals_only = TRUE)
Condition
Warning:
The compound task IDs horizon and target_end_date have all optional values. Representation of compound sample modeling tasks is not fully specified.
Output
# A tibble: 4 x 5
reference_date location variant output_type output_type_id
<date> <chr> <chr> <chr> <chr>
1 2022-10-22 US AA sample 1
2 2022-10-22 01 AA sample 2
3 2022-10-22 US BB sample 3
4 2022-10-22 01 BB sample 4

---

Code
expand_model_out_grid(config_tasks, round_id = "2022-10-22",
include_sample_ids = FALSE, bind_model_tasks = FALSE, output_types = "sample",
derived_task_ids = c("location", "variant"))
Condition
Error in `expand_model_out_grid()`:
x Derived task IDs cannot have required task ID values.
! "location" and "variant" have required task ID values. Ignored.

# expand_model_out_grid errors correctly

Code
Expand Down
37 changes: 37 additions & 0 deletions tests/testthat/test-expand_model_out_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,43 @@ test_that("expand_model_out_grid output type subsetting works", {
)
})

test_that("expand_model_out_grid derived_task_ids ignoring works", {
config_tasks <- hubUtils::read_config(test_path("testdata", "hub-spl"))

expect_snapshot(
expand_model_out_grid(config_tasks,
round_id = "2022-10-22",

Check warning on line 404 in tests/testthat/test-expand_model_out_grid.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-expand_model_out_grid.R,line=404,col=26,[indentation_linter] Indentation should be 6 spaces but is 26 spaces.
include_sample_ids = FALSE,
bind_model_tasks = TRUE,
output_types = "sample",
derived_task_ids = "target_end_date"
)
)
expect_snapshot(
expand_model_out_grid(config_tasks,
round_id = "2022-10-22",

Check warning on line 413 in tests/testthat/test-expand_model_out_grid.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-expand_model_out_grid.R,line=413,col=26,[indentation_linter] Indentation should be 6 spaces but is 26 spaces.
include_sample_ids = TRUE,
bind_model_tasks = TRUE,
output_types = "sample",
derived_task_ids = "target_end_date",
required_vals_only = TRUE
)
)

expect_snapshot(
expand_model_out_grid(config_tasks,
round_id = "2022-10-22",

Check warning on line 424 in tests/testthat/test-expand_model_out_grid.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-expand_model_out_grid.R,line=424,col=26,[indentation_linter] Indentation should be 6 spaces but is 26 spaces.
include_sample_ids = FALSE,
bind_model_tasks = FALSE,
output_types = "sample",
derived_task_ids = c("location", "variant")
),
error = TRUE
)
})



test_that("expand_model_out_grid errors correctly", {
# Specifying a round in a hub with multiple rounds
hub_con <- hubData::connect_hub(
Expand Down

0 comments on commit 2cce180

Please sign in to comment.