Skip to content

Commit

Permalink
Merge pull request #160 from hubverse-org/ak/v4-point-estimate-nulls/156
Browse files Browse the repository at this point in the history
  • Loading branch information
annakrystalli authored Nov 14, 2024
2 parents 9a3101e + bb36c4d commit d0fc091
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 28 deletions.
110 changes: 82 additions & 28 deletions R/expand_model_out_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,25 +194,10 @@ expand_model_out_grid <- function(config_tasks,
# retired
config_tid <- hubUtils::get_config_tid(config_tasks = config_tasks)

output_type_l <- purrr::map(
round_config[["model_tasks"]],
function(.x) {
out <- .x[["output_type"]]
if (is.null(output_types)) {
out
} else {
mt_output_types <- output_types[output_types %in% names(out)]
out[mt_output_types]
}
}
) %>%
purrr::map(
~ extract_mt_output_type_ids(.x, config_tid)
) %>%
output_type_l <- subset_rnd_output_types(round_config, output_types) %>%
extract_rnd_output_type_ids(config_tid) %>%
process_grid_inputs(required_vals_only = required_vals_only) %>%
purrr::map(function(.x) {
purrr::compact(.x)
})
purrr::map(~ purrr::compact(.x))

# Expand output grid individually for each modeling task and output type.
grid <- purrr::map2(
Expand Down Expand Up @@ -241,6 +226,29 @@ expand_model_out_grid <- function(config_tasks,
)
}

# Subset output types according to `output_types` from all model_task objects in
# a round. If `output_types` is `NULL`, all output types for each model task are
# returned.
subset_rnd_output_types <- function(round_config, output_types) {
purrr::map(
round_config[["model_tasks"]],
~ subset_mt_output_types(.x, output_types)
)
}

# Subset model_task object output types according to `output_types`.
# If `output_types` is `NULL`, all output types are returned.
subset_mt_output_types <- function(model_task, output_types) {
out <- model_task[["output_type"]]
if (is.null(output_types)) {
out
} else {
mt_output_types <- output_types[output_types %in% names(out)]
out[mt_output_types]
}
}


# Extracts/collapses individual task ID values depending on whether all or just required
# values are needed.
process_grid_inputs <- function(x, required_vals_only = FALSE) {
Expand Down Expand Up @@ -538,25 +546,71 @@ get_sample_n <- function(x, config_tid) {
length()
}


# Extract the output_type_id values for each model_task object in a round.
# Input should be the output of subset_rnd_output_types.
# config_tid is the name of the output_type_id column in the config schema used
# for back-compatibility with schema versions < v2.0.0. Returns a list of
# `required` and `optional` or just `required` vectors of values as appropriate for
# each output type in each model task in the round.
extract_rnd_output_type_ids <- function(x, config_tid) {
purrr::map(x, ~ extract_mt_output_type_ids(.x, config_tid))
}
# Extract the output_type_id values from a model_task object.
# config_tid is the name of the output_type_id column in the config schema used
# for back-compatibility with schema versions < v2.0.0. Returns a list of
# `required` and `optional` or just `required` vectors of values as appropriate for
# each output type in the model task.
extract_mt_output_type_ids <- function(x, config_tid) {
purrr::map(
x,
function(.x) {
if (config_tid %in% names(.x)) {
.x[[config_tid]]
} else if ("output_type_id_params" %in% names(.x)) {
if (.x[["output_type_id_params"]][["is_required"]]) {
list(required = NA, optional = NULL)
} else {
list(required = NULL, optional = NA)
}
} else {
NULL
output_type_ids <- .x[[config_tid]]
if (valid_output_type_ids(output_type_ids)) {
return(output_type_ids)
}

# If dealing with a `NULL` output_type_ids object or a v4 schema version point
# estimate output type, determine first if .
is_required <- isTRUE(.x[["is_required"]]) ||
isTRUE(.x[["output_type_id_params"]][["is_required"]])

null_output_type_ids(is_required)
}
)
}

valid_output_type_ids <- function(output_type_ids) {
# If output_type_id values are provided and when dealing with an older
# config version that has "required" and"optional" fields or extract output_type_id values
has_output_type_ids <- !is.null(output_type_ids)
pre_v4 <- isTRUE(
all.equal(
sort(names(output_type_ids)),
sort(c("required", "optional"))
)
)
# In post v4 config schema versions, when not NULL, a single `required` element is
# a valid output type id configuration and should be returned as is
required_not_null <- !is.null(output_type_ids[["required"]])

# Valid output type id configurations cannot be `NULL` and must either:
# have both `required` and `optional` elements or
# be a single non-NULL `required` element in post v4 schema versions
has_output_type_ids && (pre_v4 || required_not_null)
}

# Create a list of NULL or NA required and optional output type id values depending
# on whether the output type is required or optional. Allows us to use current
# infrastructure to convert `NULL`s to `NA`s in a back-compatible way.
null_output_type_ids <- function(is_required) {
if (is_required) {
list(required = NA, optional = NULL)
} else {
list(required = NULL, optional = NA)
}
}

validate_output_types <- function(output_types, config_tasks, round_id,
call = rlang::caller_call()) {
checkmate::assert_character(output_types, null.ok = TRUE)
Expand Down
16 changes: 16 additions & 0 deletions tests/testthat/test-expand_model_out_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -551,5 +551,21 @@ test_that("(#123) expand_output_type_grid() returns expected outputs with option
)
expect_equal(nrow(i_have_eight_rows), 8)
expect_equal(ncol(i_have_eight_rows), 3)
})

test_that("v4 point estimate output type IDs extracted correctly as NAs", {
hub_path <- system.file("testhubs", "v4", "flusight", package = "hubUtils")
file_name <- "hub-baseline/2023-05-01-hub-baseline.csv"
round_id <- parse_file_name(file_name)$round_id
config_tasks <- suppressWarnings(read_config(hub_path = hub_path))

expect_true(
expand_model_out_grid(
config_tasks = config_tasks,
round_id = round_id,
output_types = "mean",
)[["output_type_id"]] |>
is.na() |>
all()
)
})

0 comments on commit d0fc091

Please sign in to comment.