updated function and tests for get_flusight_bin_endpoints
elray1 committed Sep 25, 2024
1 parent 758083e commit faa70dd
Showing 4 changed files with 292 additions and 21 deletions.
34 changes: 13 additions & 21 deletions R/get_flusight_bin_endpoints.R
@@ -4,10 +4,10 @@
#' including the columns `location`, `date`, and `value`
#' @param location_meta Data frame with metadata about locations for FluSight,
#' matching the format of the file in cdcepi/FluSight-forecast-hub/auxiliary-data/locations.csv
-#' @param season String naming the season: either "2022/23" or "2023/24"
+#' @param season String naming the season: only "2023/24" is supported
#'
#' @details Compute the bin endpoints used for categorical targets in FluSight
-#' in the 2022/23 or 2023/24 season. In both seasons, there were 5 categories:
+#' in the 2023/24 season. In that season, there were 5 categories:
#' "large decrease", "decrease", "stable", "increase", and "large increase".
#' The bin endpoints have the form
#' value +/- max(multiplier * population / 100k, min_count),
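As a worked example of this formula, a minimal sketch (the last observed value of 100 is illustrative; the population is Wyoming's from locations.csv, and the multiplier of 1 with min_count 9.5 is the 2023/24 horizon-0 "stable" setting exercised in the tests below):

# Wyoming ("56"): population 578,583, so population / 100k is about 5.79
last_value <- 100                    # illustrative last observed count
pop100k <- 578583 / 100000
half_width <- max(1 * pop100k, 9.5)  # max(5.79, 9.5) = 9.5
c(lower = last_value - half_width, upper = last_value + half_width)
#> lower upper
#>  90.5 109.5
# with half-open bins (lower, upper], integer counts 91..109 are "stable"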
@@ -56,36 +56,28 @@ get_flusight_bin_endpoints <- function(target_ts, location_meta, season) {
}

get_flusight_bin_endpoint_meta <- function(season) {
-  if (season == "2022/23") {
-    return(get_flusight_bin_endpoint_meta_2223())
-  } else if (season == "2023/24") {
+  if (season == "2023/24") {
    return(get_flusight_bin_endpoint_meta_2324())
  } else {
    stop("unsupported season")
  }
}

-get_flusight_bin_endpoint_meta_2223 <- function() {
-  bin_endpoint_meta <- data.frame(
-    output_type_id = c("large_decrease", "decrease", "stable", "increase", "large_increase"),
-    lower_sign = c(-1, -1, -1, 1, 1),
-    lower_rate_multiplier = c(Inf, 2, 1, 1, 2),
-    lower_min_count_change = c(Inf, 40, 20, 20, 40),
-    upper_sign = c(-1, -1, 1, 1, 1),
-    upper_rate_multiplier = c(2, 1, 1, 2, Inf),
-    upper_min_count_change = c(40, 20, 20, 40, Inf),
-    stringsAsFactors = FALSE
-  )
-
-  return(bin_endpoint_meta)
-}

get_flusight_bin_endpoint_meta_2324 <- function() {
+  # we use 9.5 for the minimum count change because for low-population states,
+  # the intervals have the half-open form (value - 9.5, value + 9.5] for stable,
+  # (value - ***, value - 9.5] for decrease and large decrease,
+  # (value + 9.5, value + ***] for increase and large increase,
+  # where *** is something coming from a population rate per 100k since all
+  # state populations / 100000 are greater than 5
+  # In all cases, this says a count change of less than 10 is stable
+  # and a count change of 10 or more is non-stable
  bin_endpoint_meta <- tidyr::expand_grid(
    horizon = 0:3,
    output_type_id = list(c("large_decrease", "decrease", "stable", "increase", "large_increase")),
-    lower_min_count_change = 10,
-    upper_min_count_change = 10
+    lower_min_count_change = 9.5,
+    upper_min_count_change = 9.5
  )

bin_endpoint_meta$rate_multiplier_endpoints <- list(
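A small sketch of the new behavior under illustrative values (the baseline of 300 mirrors the Alaska test case added below; bin membership follows the tests' convention that lower < value <= upper):

# only the 2023/24 season is dispatched now; anything else errors
# try(get_flusight_bin_endpoint_meta("2022/23"))  # Error: unsupported season

# 9.5 endpoints for a low-population state (pop100k < 9.5), horizon 0 "stable":
last_value <- 300
lower <- last_value - 9.5  # 290.5
upper <- last_value + 9.5  # 309.5
(309 > lower) && (309 <= upper)  # TRUE: a count change of +9 is "stable"
(310 > lower) && (310 <= upper)  # FALSE: a change of +10 falls into "increase"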
54 changes: 54 additions & 0 deletions tests/testthat/fixtures/location_meta_24.csv
@@ -0,0 +1,54 @@
abbreviation,location,location_name,population,...5,count_rate1,count_rate2,count_rate2p5,count_rate3,count_rate4,count_rate5
US,US,US,332200066,NA,3322,6644,8305,9966,13288,16610
AL,01,Alabama,5063778,NA,51,101,127,152,203,253
AK,02,Alaska,711426,NA,7,14,18,21,28,36
AZ,04,Arizona,7341018,NA,73,147,184,220,294,367
AR,05,Arkansas,3041878,NA,30,61,76,91,122,152
CA,06,California,38886551,NA,389,778,972,1167,1555,1944
CO,08,Colorado,5803748,NA,58,116,145,174,232,290
CT,09,Connecticut,3621089,NA,36,72,91,109,145,181
DE,10,Delaware,1014872,NA,10,20,25,30,41,51
DC,11,District of Columbia,668576,NA,7,13,17,20,27,33
FL,12,Florida,22183852,NA,222,444,555,666,887,1109
GA,13,Georgia,10855454,NA,109,217,271,326,434,543
HI,15,Hawaii,1398977,NA,14,28,35,42,56,70
ID,16,Idaho,1935521,NA,19,39,48,58,77,97
IL,17,Illinois,12563096,NA,126,251,314,377,503,628
IN,18,Indiana,6831941,NA,68,137,171,205,273,342
IA,19,Iowa,3200224,NA,32,64,80,96,128,160
KS,20,Kansas,2916451,NA,29,58,73,87,117,146
KY,21,Kentucky,4494379,NA,45,90,112,135,180,225
LA,22,Louisiana,4575074,NA,46,92,114,137,183,229
ME,23,Maine,1384543,NA,14,28,35,42,55,69
MD,24,Maryland,6133130,NA,61,123,153,184,245,307
MA,25,Massachusetts,6978662,NA,70,140,174,209,279,349
MI,26,Michigan,10032075,NA,100,201,251,301,401,502
MN,27,Minnesota,5716548,NA,57,114,143,171,229,286
MS,28,Mississippi,2927305,NA,29,59,73,88,117,146
MO,29,Missouri,6164537,NA,62,123,154,185,247,308
MT,30,Montana,1119563,NA,11,22,28,34,45,56
NE,31,Nebraska,1961505,NA,20,39,49,59,78,98
NV,32,Nevada,3165539,NA,32,63,79,95,127,158
NH,33,New Hampshire,1394692,NA,14,28,35,42,56,70
NJ,34,New Jersey,9254137,NA,93,185,231,278,370,463
NM,35,New Mexico,2100079,NA,21,42,53,63,84,105
NY,36,New York,19657190,NA,197,393,491,590,786,983
NC,37,North Carolina,10596562,NA,106,212,265,318,424,530
ND,38,North Dakota,772061,NA,8,15,19,23,31,39
OH,39,Ohio,11749303,NA,117,235,294,352,470,587
OK,40,Oklahoma,4001266,NA,40,80,100,120,160,200
OR,41,Oregon,4238665,NA,42,85,106,127,170,212
PA,42,Pennsylvania,12969276,NA,130,259,324,389,519,648
RI,44,Rhode Island,1090390,NA,11,22,27,33,44,55
SC,45,South Carolina,5246039,NA,52,105,131,157,210,262
SD,46,South Dakota,906458,NA,9,18,23,27,36,45
TN,47,Tennessee,7030607,NA,70,141,176,211,281,352
TX,48,Texas,29914599,NA,299,598,748,897,1197,1496
UT,49,Utah,3376238,NA,34,68,84,101,135,169
VT,50,Vermont,646910,NA,6,13,16,19,26,32
VA,51,Virginia,8583866,NA,86,172,215,258,343,429
WA,53,Washington,7735834,NA,77,155,193,232,309,387
WV,54,West Virginia,1774977,NA,18,35,44,53,71,89
WI,55,Wisconsin,5891022,NA,59,118,147,177,236,295
WY,56,Wyoming,578583,NA,6,12,14,17,23,29
PR,72,Puerto Rico,3221789,NA,32,64,81,97,129,161
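An observed pattern in this snapshot (not a documented guarantee): the count_rateX columns look like round(X * population / 100000). A quick check, assuming it is run from the package root:

# US: round(332200066 / 1e5) = 3322, matching count_rate1
# AK: round(2.5 * 711426 / 1e5) = round(17.79) = 18, matching count_rate2p5
loc <- read.csv("tests/testthat/fixtures/location_meta_24.csv")
head(cbind(loc$count_rate1, round(loc$population / 1e5)))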
7 changes: 7 additions & 0 deletions tests/testthat/fixtures/setup_flusight_fixtures.R
@@ -0,0 +1,7 @@
# location metadata, 2022-23
# location_meta_23 <- read.csv("https://raw.githubusercontent.com/cdcepi/Flusight-forecast-data/master/data-locations/locations.csv")
# readr::write_csv(location_meta_23, "tests/testthat/fixtures/location_meta_23.csv")

# location metadata, 2023-24
location_meta_24 <- readr::read_csv("https://raw.githubusercontent.com/cdcepi/FluSight-forecast-hub/refs/heads/main/auxiliary-data/locations.csv")
readr::write_csv(location_meta_24, "tests/testthat/fixtures/location_meta_24.csv")
218 changes: 218 additions & 0 deletions tests/testthat/test-get_flusight_bin_endpoints.R
@@ -0,0 +1,218 @@
test_that("get_flusight_bin_endpoints works, 2023/24", {
# definitions at
# https://github.com/cdcepi/FluSight-forecast-hub/tree/main/model-output#rate-trend-forecast-specifications
#
# our strategy is to:
# - construct data that should fall into known categories
# (with every horizon/category combination, and both criteria for stable)
# - compute bins and apply them to the data
# - check that we got the right answers
location_meta <- readr::read_csv(
file = testthat::test_path("fixtures", "location_meta_24.csv")
) |>
dplyr::mutate(
pop100k = .data[["population"]] / 100000
)

# locations for testing "stable", "increase" and "decrease" thresholds are:
  # 56 = Wyoming, pop100k = 5.79
# 02 = Alaska, pop100k = 7.11 and 11 = District of Columbia, pop100k = 6.69
# these states trigger the "minimum count change at least 10" rule
locs <- c("US", "56", "02", "11", "05", "06")

# create data
# our reference date will be 2023-10-21.
# changes are relative to 2023-10-14.
  # 2023-10-07 is a throw-away, to make sure we grab the right "relative to" date
target_data <- tidyr::expand_grid(
location = locs,
date = as.Date("2023-10-07") + seq(from = 0, by = 7, length.out = 6),
value = NA
)

expected_categories <- NULL

  # all expected category levels are "stable": count change less than
  # 1, 1, 2, or 2.5 (by horizon) * population / 100k
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[1]]
target_data$value[target_data$location == locs[1]] <- c(
0,
10000,
10000 + floor(0.99 * loc_pop100k),
10000 - floor(0.99 * loc_pop100k),
10000 + floor(1.99 * loc_pop100k),
10000 - floor(2.49 * loc_pop100k)
)
expected_categories <- dplyr::bind_rows(
expected_categories,
target_data |>
dplyr::filter(
.data[["location"]] == locs[1],
.data[["date"]] >= "2023-10-21"
) |>
dplyr::mutate(
horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7),
output_type_id = "stable"
)
)

# all expected category levels are "stable": count change less than 10
target_data$value[target_data$location == locs[2]] <- c(0, 300, 304, 297, 309, 291)
expected_categories <- dplyr::bind_rows(
expected_categories,
target_data |>
dplyr::filter(
.data[["location"]] == locs[2],
.data[["date"]] >= "2023-10-21"
) |>
dplyr::mutate(
horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7),
output_type_id = "stable"
)
)

# all expected category levels are "increase": count change >= 10,
# horizon 0: 1 <= rate change < 2
# horizon 1: 1 <= rate change < 3
# horizon 2: 2 <= rate change < 4
# horizon 3: 2.5 <= rate change < 5
# note, loc_pop100k for this location is 7.11 < 10
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[3]]
target_data$value[target_data$location == locs[3]] <- c(
0,
10000,
10000 + 10,
10000 + floor(2.99 * loc_pop100k),
10000 + floor(3.99 * loc_pop100k),
10000 + ceiling(2.51 * loc_pop100k)
)
expected_categories <- dplyr::bind_rows(
expected_categories,
target_data |>
dplyr::filter(
.data[["location"]] == locs[3],
.data[["date"]] >= "2023-10-21"
) |>
dplyr::mutate(
horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7),
output_type_id = "increase"
)
)

# all expected category levels are "decrease": count change <= -10,
# horizon 0: -1 >= rate change > -2
# horizon 1: -1 >= rate change > -3
# horizon 2: -2 >= rate change > -4
# horizon 3: -2.5 >= rate change > -5
  # note, loc_pop100k for this location is 6.69 < 10
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[4]]
target_data$value[target_data$location == locs[4]] <- c(
0,
10000,
10000 - 10,
10000 - floor(2.99 * loc_pop100k),
10000 - floor(3.99 * loc_pop100k),
10000 - ceiling(2.51 * loc_pop100k)
)
expected_categories <- dplyr::bind_rows(
expected_categories,
target_data |>
dplyr::filter(
.data[["location"]] == locs[4],
.data[["date"]] >= "2023-10-21"
) |>
dplyr::mutate(
horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7),
output_type_id = "decrease"
)
)

# all expected category levels are "large increase": count change >= 10,
# horizon 0: 2 <= rate change
# horizon 1: 3 <= rate change
# horizon 2: 4 <= rate change
# horizon 3: 5 <= rate change
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[5]]
target_data$value[target_data$location == locs[5]] <- c(
0,
10000,
10000 + max(10, ceiling(2 * loc_pop100k)),
10000 + max(10, ceiling(3 * loc_pop100k)),
10000 + max(10, ceiling(4 * loc_pop100k)),
10000 + max(10, ceiling(5 * loc_pop100k))
)

expected_categories <- dplyr::bind_rows(
expected_categories,
target_data |>
dplyr::filter(
.data[["location"]] == locs[5],
.data[["date"]] >= "2023-10-21"
) |>
dplyr::mutate(
horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7),
output_type_id = "large_increase"
)
)

# all expected category levels are "large decrease": count change <= -10,
# horizon 0: -2 >= rate change
# horizon 1: -3 >= rate change
# horizon 2: -4 >= rate change
# horizon 3: -5 >= rate change
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[6]]
target_data$value[target_data$location == locs[6]] <- c(
0,
10000,
10000 - max(10, ceiling(2 * loc_pop100k)),
10000 - max(10, ceiling(3 * loc_pop100k)),
10000 - max(10, ceiling(4 * loc_pop100k)),
10000 - max(10, ceiling(5 * loc_pop100k))
)

expected_categories <- dplyr::bind_rows(
expected_categories,
target_data |>
dplyr::filter(
.data[["location"]] == locs[6],
.data[["date"]] >= "2023-10-21"
) |>
dplyr::mutate(
horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7),
output_type_id = "large_decrease"
)
)

bin_endpoints <- get_flusight_bin_endpoints(
target_ts = target_data |>
dplyr::filter(
.data[["date"]] < "2023-10-21"
),
location_meta = location_meta,
season = "2023/24"
)

actual_categories <- bin_endpoints |>
dplyr::mutate(
reference_date = as.Date("2023-10-21"),
target_end_date = as.Date("2023-10-21") + 7 * .data[["horizon"]]
) |>
dplyr::left_join(
target_data,
by = c("location", "target_end_date" = "date")
) |>
dplyr::filter(
.data[["lower"]] < .data[["value"]],
.data[["value"]] <= .data[["upper"]]
)

mismatched_categorizations <- expected_categories |>
dplyr::left_join(
actual_categories,
by = c("location", "date" = "reference_date", "horizon")
) |>
dplyr::filter(output_type_id.x != output_type_id.y)

# expect no mismatches!
expect_equal(nrow(mismatched_categorizations), 0L)
})
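For reference, the horizon-dependent rate-change thresholds these cases exercise, collected from the test comments into one table (a summary sketch; the data frame and column names are illustrative, not part of the package API):

flusight_2324_thresholds <- data.frame(
  horizon      = 0:3,
  stable_below = c(1, 1, 2, 2.5), # |rate change| below this (or count change < 10) is "stable"
  large_from   = c(2, 3, 4, 5)    # |rate change| at or above this is a "large" increase/decrease
)
# "increase" / "decrease" fall between stable_below and large_from at each horizon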
