-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
updated function and tests for get_flusight_bin_endpoints
- Loading branch information
Showing
4 changed files
with
292 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
abbreviation,location,location_name,population,...5,count_rate1,count_rate2,count_rate2p5,count_rate3,count_rate4,count_rate5 | ||
US,US,US,332200066,NA,3322,6644,8305,9966,13288,16610 | ||
AL,01,Alabama,5063778,NA,51,101,127,152,203,253 | ||
AK,02,Alaska,711426,NA,7,14,18,21,28,36 | ||
AZ,04,Arizona,7341018,NA,73,147,184,220,294,367 | ||
AR,05,Arkansas,3041878,NA,30,61,76,91,122,152 | ||
CA,06,California,38886551,NA,389,778,972,1167,1555,1944 | ||
CO,08,Colorado,5803748,NA,58,116,145,174,232,290 | ||
CT,09,Connecticut,3621089,NA,36,72,91,109,145,181 | ||
DE,10,Delaware,1014872,NA,10,20,25,30,41,51 | ||
DC,11,District of Columbia,668576,NA,7,13,17,20,27,33 | ||
FL,12,Florida,22183852,NA,222,444,555,666,887,1109 | ||
GA,13,Georgia,10855454,NA,109,217,271,326,434,543 | ||
HI,15,Hawaii,1398977,NA,14,28,35,42,56,70 | ||
ID,16,Idaho,1935521,NA,19,39,48,58,77,97 | ||
IL,17,Illinois,12563096,NA,126,251,314,377,503,628 | ||
IN,18,Indiana,6831941,NA,68,137,171,205,273,342 | ||
IA,19,Iowa,3200224,NA,32,64,80,96,128,160 | ||
KS,20,Kansas,2916451,NA,29,58,73,87,117,146 | ||
KY,21,Kentucky,4494379,NA,45,90,112,135,180,225 | ||
LA,22,Louisiana,4575074,NA,46,92,114,137,183,229 | ||
ME,23,Maine,1384543,NA,14,28,35,42,55,69 | ||
MD,24,Maryland,6133130,NA,61,123,153,184,245,307 | ||
MA,25,Massachusetts,6978662,NA,70,140,174,209,279,349 | ||
MI,26,Michigan,10032075,NA,100,201,251,301,401,502 | ||
MN,27,Minnesota,5716548,NA,57,114,143,171,229,286 | ||
MS,28,Mississippi,2927305,NA,29,59,73,88,117,146 | ||
MO,29,Missouri,6164537,NA,62,123,154,185,247,308 | ||
MT,30,Montana,1119563,NA,11,22,28,34,45,56 | ||
NE,31,Nebraska,1961505,NA,20,39,49,59,78,98 | ||
NV,32,Nevada,3165539,NA,32,63,79,95,127,158 | ||
NH,33,New Hampshire,1394692,NA,14,28,35,42,56,70 | ||
NJ,34,New Jersey,9254137,NA,93,185,231,278,370,463 | ||
NM,35,New Mexico,2100079,NA,21,42,53,63,84,105 | ||
NY,36,New York,19657190,NA,197,393,491,590,786,983 | ||
NC,37,North Carolina,10596562,NA,106,212,265,318,424,530 | ||
ND,38,North Dakota,772061,NA,8,15,19,23,31,39 | ||
OH,39,Ohio,11749303,NA,117,235,294,352,470,587 | ||
OK,40,Oklahoma,4001266,NA,40,80,100,120,160,200 | ||
OR,41,Oregon,4238665,NA,42,85,106,127,170,212 | ||
PA,42,Pennsylvania,12969276,NA,130,259,324,389,519,648 | ||
RI,44,Rhode Island,1090390,NA,11,22,27,33,44,55 | ||
SC,45,South Carolina,5246039,NA,52,105,131,157,210,262 | ||
SD,46,South Dakota,906458,NA,9,18,23,27,36,45 | ||
TN,47,Tennessee,7030607,NA,70,141,176,211,281,352 | ||
TX,48,Texas,29914599,NA,299,598,748,897,1197,1496 | ||
UT,49,Utah,3376238,NA,34,68,84,101,135,169 | ||
VT,50,Vermont,646910,NA,6,13,16,19,26,32 | ||
VA,51,Virginia,8583866,NA,86,172,215,258,343,429 | ||
WA,53,Washington,7735834,NA,77,155,193,232,309,387 | ||
WV,54,West Virginia,1774977,NA,18,35,44,53,71,89 | ||
WI,55,Wisconsin,5891022,NA,59,118,147,177,236,295 | ||
WY,56,Wyoming,578583,NA,6,12,14,17,23,29 | ||
PR,72,Puerto Rico,3221789,NA,32,64,81,97,129,161 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# location metadata, 2022-23 | ||
# location_meta_23 <-read.csv("https://raw.githubusercontent.com/cdcepi/Flusight-forecast-data/master/data-locations/locations.csv") | ||
# readr::write_csv(location_meta_23, "tests/testthat/fixtures/location_meta_23.csv") | ||
|
||
# location metadata, 2023-24 | ||
location_meta_24 <- readr::read_csv("https://raw.githubusercontent.com/cdcepi/FluSight-forecast-hub/refs/heads/main/auxiliary-data/locations.csv") | ||
readr::write_csv(location_meta_24, "tests/testthat/fixtures/location_meta_24.csv") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
test_that("get_flusight_bin_endpoints works, 2023/24", { | ||
# definitions at | ||
# https://github.com/cdcepi/FluSight-forecast-hub/tree/main/model-output#rate-trend-forecast-specifications | ||
# | ||
# our strategy is to: | ||
# - construct data that should fall into known categories | ||
# (with every horizon/category combination, and both criteria for stable) | ||
# - compute bins and apply them to the data | ||
# - check that we got the right answers | ||
location_meta <- readr::read_csv( | ||
file = testthat::test_path("fixtures", "location_meta_24.csv") | ||
) |> | ||
dplyr::mutate( | ||
pop100k = .data[["population"]] / 100000 | ||
) | ||
|
||
# locations for testing "stable", "increase" and "decrease" thresholds are: | ||
# 56 = Wyoming, pop100k = 5.78 | ||
# 02 = Alaska, pop100k = 7.11 and 11 = District of Columbia, pop100k = 6.69 | ||
# these states trigger the "minimum count change at least 10" rule | ||
locs <- c("US", "56", "02", "11", "05", "06") | ||
|
||
# create data | ||
# our reference date will be 2023-10-21. | ||
# changes are relative to 2023-10-14. | ||
# 2023-10-07 is throw-away, to make sure we grab the right "relative to" date | ||
target_data <- tidyr::expand_grid( | ||
location = locs, | ||
date = as.Date("2023-10-07") + seq(from = 0, by = 7, length.out = 6), | ||
value = NA | ||
) | ||
|
||
expected_categories <- NULL | ||
|
||
# all expected category levels are "stable": rate change less than | ||
# 1, 1, 2, or 2.5 * population rate | ||
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[1]] | ||
target_data$value[target_data$location == locs[1]] <- c( | ||
0, | ||
10000, | ||
10000 + floor(0.99 * loc_pop100k), | ||
10000 - floor(0.99 * loc_pop100k), | ||
10000 + floor(1.99 * loc_pop100k), | ||
10000 - floor(2.49 * loc_pop100k) | ||
) | ||
expected_categories <- dplyr::bind_rows( | ||
expected_categories, | ||
target_data |> | ||
dplyr::filter( | ||
.data[["location"]] == locs[1], | ||
.data[["date"]] >= "2023-10-21" | ||
) |> | ||
dplyr::mutate( | ||
horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7), | ||
output_type_id = "stable" | ||
) | ||
) | ||
|
||
# all expected category levels are "stable": count change less than 10 | ||
target_data$value[target_data$location == locs[2]] <- c(0, 300, 304, 297, 309, 291) | ||
expected_categories <- dplyr::bind_rows( | ||
expected_categories, | ||
target_data |> | ||
dplyr::filter( | ||
.data[["location"]] == locs[2], | ||
.data[["date"]] >= "2023-10-21" | ||
) |> | ||
dplyr::mutate( | ||
horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7), | ||
output_type_id = "stable" | ||
) | ||
) | ||
|
||
# all expected category levels are "increase": count change >= 10, | ||
# horizon 0: 1 <= rate change < 2 | ||
# horizon 1: 1 <= rate change < 3 | ||
# horizon 2: 2 <= rate change < 4 | ||
# horizon 3: 2.5 <= rate change < 5 | ||
# note, loc_pop100k for this location is 7.11 < 10 | ||
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[3]] | ||
target_data$value[target_data$location == locs[3]] <- c( | ||
0, | ||
10000, | ||
10000 + 10, | ||
10000 + floor(2.99 * loc_pop100k), | ||
10000 + floor(3.99 * loc_pop100k), | ||
10000 + ceiling(2.51 * loc_pop100k) | ||
) | ||
expected_categories <- dplyr::bind_rows( | ||
expected_categories, | ||
target_data |> | ||
dplyr::filter( | ||
.data[["location"]] == locs[3], | ||
.data[["date"]] >= "2023-10-21" | ||
) |> | ||
dplyr::mutate( | ||
horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7), | ||
output_type_id = "increase" | ||
) | ||
) | ||
|
||
# all expected category levels are "decrease": count change <= -10, | ||
# horizon 0: -1 >= rate change > -2 | ||
# horizon 1: -1 >= rate change > -3 | ||
# horizon 2: -2 >= rate change > -4 | ||
# horizon 3: -2.5 >= rate change > -5 | ||
# note, loc_pop100k for this location is 73.4 >= 10 | ||
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[4]] | ||
target_data$value[target_data$location == locs[4]] <- c( | ||
0, | ||
10000, | ||
10000 - 10, | ||
10000 - floor(2.99 * loc_pop100k), | ||
10000 - floor(3.99 * loc_pop100k), | ||
10000 - ceiling(2.51 * loc_pop100k) | ||
) | ||
expected_categories <- dplyr::bind_rows( | ||
expected_categories, | ||
target_data |> | ||
dplyr::filter( | ||
.data[["location"]] == locs[4], | ||
.data[["date"]] >= "2023-10-21" | ||
) |> | ||
dplyr::mutate( | ||
horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7), | ||
output_type_id = "decrease" | ||
) | ||
) | ||
|
||
# all expected category levels are "large increase": count change >= 10, | ||
# horizon 0: 2 <= rate change | ||
# horizon 1: 3 <= rate change | ||
# horizon 2: 4 <= rate change | ||
# horizon 3: 5 <= rate change | ||
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[5]] | ||
target_data$value[target_data$location == locs[5]] <- c( | ||
0, | ||
10000, | ||
10000 + max(10, ceiling(2 * loc_pop100k)), | ||
10000 + max(10, ceiling(3 * loc_pop100k)), | ||
10000 + max(10, ceiling(4 * loc_pop100k)), | ||
10000 + max(10, ceiling(5 * loc_pop100k)) | ||
) | ||
|
||
expected_categories <- dplyr::bind_rows( | ||
expected_categories, | ||
target_data |> | ||
dplyr::filter( | ||
.data[["location"]] == locs[5], | ||
.data[["date"]] >= "2023-10-21" | ||
) |> | ||
dplyr::mutate( | ||
horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7), | ||
output_type_id = "large_increase" | ||
) | ||
) | ||
|
||
# all expected category levels are "large decrease": count change <= -10, | ||
# horizon 0: -2 >= rate change | ||
# horizon 1: -3 >= rate change | ||
# horizon 2: -4 >= rate change | ||
# horizon 3: -5 >= rate change | ||
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[6]] | ||
target_data$value[target_data$location == locs[6]] <- c( | ||
0, | ||
10000, | ||
10000 - max(10, ceiling(2 * loc_pop100k)), | ||
10000 - max(10, ceiling(3 * loc_pop100k)), | ||
10000 - max(10, ceiling(4 * loc_pop100k)), | ||
10000 - max(10, ceiling(5 * loc_pop100k)) | ||
) | ||
|
||
expected_categories <- dplyr::bind_rows( | ||
expected_categories, | ||
target_data |> | ||
dplyr::filter( | ||
.data[["location"]] == locs[6], | ||
.data[["date"]] >= "2023-10-21" | ||
) |> | ||
dplyr::mutate( | ||
horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7), | ||
output_type_id = "large_decrease" | ||
) | ||
) | ||
|
||
bin_endpoints <- get_flusight_bin_endpoints( | ||
target_ts = target_data |> | ||
dplyr::filter( | ||
.data[["date"]] < "2023-10-21" | ||
), | ||
location_meta = location_meta, | ||
season = "2023/24" | ||
) | ||
|
||
actual_categories <- bin_endpoints |> | ||
dplyr::mutate( | ||
reference_date = as.Date("2023-10-21"), | ||
target_end_date = as.Date("2023-10-21") + 7 * .data[["horizon"]] | ||
) |> | ||
dplyr::left_join( | ||
target_data, | ||
by = c("location", "target_end_date" = "date") | ||
) |> | ||
dplyr::filter( | ||
.data[["lower"]] < .data[["value"]], | ||
.data[["value"]] <= .data[["upper"]] | ||
) | ||
|
||
mismatched_categorizations <- expected_categories |> | ||
dplyr::left_join( | ||
actual_categories, | ||
by = c("location", "date" = "reference_date", "horizon") | ||
) |> | ||
dplyr::filter(output_type_id.x != output_type_id.y) | ||
|
||
# expect no mismatches! | ||
expect_equal(nrow(mismatched_categorizations), 0L) | ||
}) |