Skip to content

Commit

Permalink
Merge pull request #126 from hubverse-org/prepare-for-cran2
Browse files Browse the repository at this point in the history
Prepare for cran pt. 2
  • Loading branch information
lshandross authored Sep 5, 2024
2 parents 054edcd + dda8e81 commit c684690
Show file tree
Hide file tree
Showing 25 changed files with 350 additions and 134 deletions.
4 changes: 2 additions & 2 deletions .github/CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,5 @@ Our procedures for contributing bigger changes, code in particular, generally fo
## Code of Conduct

Please note that the hubEnsembles project is released with a
[Contributor Code of Conduct](.github/CODE_OF_CONDUCT.md). By contributing to this
project you agree to abide by its terms.
[Contributor Code of Conduct](https://hubverse-org.github.io/hubEnsembles/CODE_OF_CONDUCT.html).
By contributing to this project you agree to abide by its terms.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ scrap code.r
hubEnsembles.code-workspace
.Rprofile
Rplot.png
/dev
revdep/
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

* Fix bug in `simple_ensemble()` that produces invalid distributions for certain weighted medians (#122)
* Replace magrittr pipe (`%>%`) with base R 4.1 pipe (`|>`)
* Simplify examples

# hubEnsembles 0.1.5

Expand Down
65 changes: 65 additions & 0 deletions R/example_data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#' Example model output data for `simple_ensemble()`
#'
#' Toy model output data formatted according to hubverse standards
#' to be used in the examples for `simple_ensemble()`
#'
#' @format ## `model_outputs`
#' A data frame with 24 rows and 8 columns:
#' \describe{
#' \item{model_id}{model ID}
#' \item{location}{FIPS codes}
#' \item{horizon}{forecast horizon}
#' \item{target}{forecast target}
#' \item{target_date}{date that the forecast is for}
#' \item{output_type}{type of forecast}
#' \item{output_type_id}{output type ID}
#' \item{value}{forecast value}
#' }
"model_outputs"

#' Example weights data for `simple_ensemble()`
#'
#' Toy weights data formatted according to hubverse standards
#' to be used in the examples for `simple_ensemble()`
#'
#' @format ## `fweights`
#' A data frame with 8 rows and 3 columns:
#' \describe{
#' \item{model_id}{model ID}
#' \item{location}{FIPS codes}
#' \item{weight}{weight}
#' }
"fweights"


#' Example model output data for `linear_pool()`
#'
#' Toy model output data formatted according to hubverse standards
#' to be used in the examples for `linear_pool()`. The predictions included
#' are taken from three normal distributions with means -3, 0, and 3,
#' each with standard deviation 1.
#'
#' @format ## `component_outputs`
#' A data frame with 123 rows and 5 columns:
#' \describe{
#' \item{model_id}{model ID}
#' \item{target}{forecast target}
#' \item{output_type}{type of forecast}
#' \item{output_type_id}{output type ID}
#' \item{value}{forecast value}
#' }
"component_outputs"

#' Example weights data for `linear_pool()`
#'
#' Toy weights data formatted according to hubverse standards
#' to be used in the examples for `linear_pool()`. Weights are 0.25, 0.5, 0.25.
#'
#' @format ## `weights`
#' A data frame with 3 rows and 2 columns:
#' \describe{
#' \item{model_id}{model ID}
#' \item{weight}{weight}
#' }
"weights"
40 changes: 11 additions & 29 deletions R/linear_pool.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,44 +30,26 @@
#' Steps 1 and 2 in this process are performed by `distfromq::make_q_fn`.
#'
#' @return a `model_out_tbl` object of ensemble predictions. Note that any
#' additional columns in the input `model_outputs` are dropped.
#' additional columns in the input `model_out_tbl` are dropped.
#'
#' @export
#'
#' @examples
#' # We illustrate the calculation of a linear pool when we have quantiles from the
#' # component models. We take the components to be normal distributions with
#' # means -3, 0, and 3, all standard deviations 1, and weights 0.25, 0.5, and 0.25.
#' library(purrr)
#' component_ids <- letters[1:3]
#' component_weights <- c(0.25, 0.5, 0.25)
#' component_means <- c(-3, 0, 3)
#' data(component_outputs)
#' data(weights)
#'
#' lp_qs <- seq(from = -5, to = 5, by = 0.25) # linear pool quantiles, expected outputs
#' ps <- rep(0, length(lp_qs))
#' for (m in seq_len(3)) {
#' ps <- ps + component_weights[m] * pnorm(lp_qs, mean = component_means[m])
#' }
#'
#' component_qs <- purrr::map(component_means, ~ qnorm(ps, mean=.x)) |> unlist()
#' component_outputs <- data.frame(
#' stringsAsFactors = FALSE,
#' model_id = rep(component_ids, each = length(lp_qs)),
#' target = "inc death",
#' output_type = "quantile",
#' output_type_id = ps,
#' value = component_qs)
#'
#' lp_from_component_qs <- linear_pool(
#' component_outputs,
#' weights = data.frame(model_id = component_ids, weight = component_weights))
#' expected_quantiles <- seq(from = -5, to = 5, by = 0.25)
#' lp_from_component_qs <- linear_pool(component_outputs, weights)
#'
#' head(lp_from_component_qs)
#' all.equal(lp_from_component_qs$value, lp_qs, tolerance = 1e-3,
#' check.attributes=FALSE)
#' all.equal(lp_from_component_qs$value, expected_quantiles, tolerance = 1e-3,
#' check.attributes = FALSE)
#'

linear_pool <- function(model_outputs, weights = NULL,
linear_pool <- function(model_out_tbl, weights = NULL,
weights_col_name = "weight",
model_id = "hub-ensemble",
task_id_cols = NULL,
Expand All @@ -76,18 +58,18 @@ linear_pool <- function(model_outputs, weights = NULL,

# validate_ensemble_inputs
valid_types <- c("mean", "quantile", "cdf", "pmf")
validated_inputs <- model_outputs |>
validated_inputs <- model_out_tbl |>
validate_ensemble_inputs(weights = weights,
weights_col_name = weights_col_name,
task_id_cols = task_id_cols,
valid_output_types = valid_types)

model_outputs_validated <- validated_inputs$model_outputs
model_out_tbl_validated <- validated_inputs$model_out_tbl
weights_validated <- validated_inputs$weights
task_id_cols_validated <- validated_inputs$task_id_cols

# calculate linear opinion pool for different types
ensemble_model_outputs <- model_outputs_validated |>
ensemble_model_outputs <- model_out_tbl_validated |>
dplyr::group_split("output_type") |>
purrr::map(.f = function(split_outputs) {
type <- split_outputs$output_type[1]
Expand Down
8 changes: 4 additions & 4 deletions R/linear_pool_quantile.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@
#' @return a `model_out_tbl` object of ensemble predictions for the `quantile` output type.
#' @importFrom rlang .data

linear_pool_quantile <- function(model_outputs, weights = NULL,
linear_pool_quantile <- function(model_out_tbl, weights = NULL,
weights_col_name = "weight",
model_id = "hub-ensemble",
task_id_cols = NULL,
n_samples = 1e4,
...) {
quantile_levels <- unique(model_outputs$output_type_id)
quantile_levels <- unique(model_out_tbl$output_type_id)

if (is.null(weights)) {
group_by_cols <- task_id_cols
agg_args <- c(list(x = quote(.data[["pred_qs"]]), probs = as.numeric(quantile_levels)))
} else {
weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name]

model_outputs <- model_outputs |>
model_out_tbl <- model_out_tbl |>
dplyr::left_join(weights, by = weight_by_cols)

agg_args <- c(list(x = quote(.data[["pred_qs"]]),
Expand All @@ -34,7 +34,7 @@ linear_pool_quantile <- function(model_outputs, weights = NULL,
}

sample_q_lvls <- seq(from = 0, to = 1, length.out = n_samples + 2)[2:n_samples]
quantile_outputs <- model_outputs |>
quantile_outputs <- model_out_tbl |>
dplyr::group_by(model_id, dplyr::across(dplyr::all_of(group_by_cols))) |>
dplyr::summarize(
pred_qs = list(
Expand Down
35 changes: 24 additions & 11 deletions R/simple_ensemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#' each combination of model task, output type, and output type id. Supported
#' output types include `mean`, `median`, `quantile`, `cdf`, and `pmf`.
#'
#' @param model_outputs an object of class `model_out_tbl` with component
#' @param model_out_tbl an object of class `model_out_tbl` with component
#' model outputs (e.g., predictions).
#' @param weights an optional `data.frame` with component model weights. If
#' provided, it should have a column named `model_id` and a column containing
Expand All @@ -20,8 +20,8 @@
#' @param model_id `character` string with the identifier to use for the
#' ensemble model.
#' @param task_id_cols `character` vector with names of columns in
#' `model_outputs` that specify modeling tasks. Defaults to `NULL`, in which
#' case all columns in `model_outputs` other than `"model_id"`, `"output_type"`,
#' `model_out_tbl` that specify modeling tasks. Defaults to `NULL`, in which
#' case all columns in `model_out_tbl` other than `"model_id"`, `"output_type"`,
#' `"output_type_id"`, and `"value"` are used as task ids.
#'
#' @details The default for `agg_fun` is `"mean"`, in which case the ensemble's
Expand All @@ -38,24 +38,37 @@
#' calculation issue that results in invalid distributions.
#'
#' @return a `model_out_tbl` object of ensemble predictions. Note that
#' any additional columns in the input `model_outputs` are dropped.
#' any additional columns in the input `model_out_tbl` are dropped.
#'
#' @export
simple_ensemble <- function(model_outputs, weights = NULL,
#'
#' @examples
#' # Calculate a weighted median in two ways
#' data(model_outputs)
#' data(fweights)
#'
#' weighted_median1 <- simple_ensemble(model_outputs, weights = fweights,
#' agg_fun = stats::median)
#' weighted_median2 <- simple_ensemble(model_outputs, weights = fweights,
#' agg_fun = matrixStats::weightedMedian)
#' all.equal(weighted_median1, weighted_median2)
#'

simple_ensemble <- function(model_out_tbl, weights = NULL,
weights_col_name = "weight",
agg_fun = "mean", agg_args = list(),
agg_fun = mean, agg_args = list(),
model_id = "hub-ensemble",
task_id_cols = NULL) {

# validate_ensemble_inputs
valid_types <- c("mean", "median", "quantile", "cdf", "pmf")
validated_inputs <- model_outputs |>
validated_inputs <- model_out_tbl |>
validate_ensemble_inputs(weights = weights,
weights_col_name = weights_col_name,
task_id_cols = task_id_cols,
valid_output_types = valid_types)

model_outputs_validated <- validated_inputs$model_outputs
model_out_tbl_validated <- validated_inputs$model_out_tbl
weights_validated <- validated_inputs$weights
task_id_cols_validated <- validated_inputs$task_id_cols

Expand All @@ -65,14 +78,14 @@ simple_ensemble <- function(model_outputs, weights = NULL,
weight_by_cols <-
colnames(weights_validated)[colnames(weights_validated) != weights_col_name]

model_outputs_validated <- model_outputs_validated |>
model_out_tbl_validated <- model_out_tbl_validated |>
dplyr::left_join(weights_validated, by = weight_by_cols)

agg_fun <- match.fun(agg_fun)

if (identical(agg_fun, mean)) {
agg_fun <- matrixStats::weightedMean
} else if (identical(agg_fun, median)) {
} else if (identical(agg_fun, stats::median)) {
agg_fun <- matrixStats::weightedMedian
}

Expand All @@ -86,7 +99,7 @@ simple_ensemble <- function(model_outputs, weights = NULL,
}

group_by_cols <- c(task_id_cols_validated, "output_type", "output_type_id")
ensemble_model_outputs <- model_outputs_validated |>
ensemble_model_outputs <- model_out_tbl_validated |>
dplyr::group_by(dplyr::across(dplyr::all_of(group_by_cols))) |>
dplyr::summarize(value = do.call(agg_fun, args = agg_args)) |>
dplyr::mutate(model_id = model_id, .before = 1) |>
Expand Down
Loading

0 comments on commit c684690

Please sign in to comment.