From faa70ddfeae7404c158f1712fe76b5407d3a1457 Mon Sep 17 00:00:00 2001
From: Evan Ray
Date: Wed, 25 Sep 2024 18:10:20 -0400
Subject: [PATCH] updated function and tests for get_flusight_bin_endpoints

---
 R/get_flusight_bin_endpoints.R               |  34 ++-
 tests/testthat/fixtures/location_meta_24.csv |  54 +++++
 .../fixtures/setup_flusight_fixtures.R       |   7 +
 .../test-get_flusight_bin_endpoints.R        | 218 ++++++++++++++++++
 4 files changed, 292 insertions(+), 21 deletions(-)
 create mode 100644 tests/testthat/fixtures/location_meta_24.csv
 create mode 100644 tests/testthat/fixtures/setup_flusight_fixtures.R
 create mode 100644 tests/testthat/test-get_flusight_bin_endpoints.R

diff --git a/R/get_flusight_bin_endpoints.R b/R/get_flusight_bin_endpoints.R
index c42d6a0..f67288f 100644
--- a/R/get_flusight_bin_endpoints.R
+++ b/R/get_flusight_bin_endpoints.R
@@ -4,10 +4,10 @@
 #' including the columns `location`, `date`, and `value`
 #' @param location_meta Data frame with metadata about locations for FluSight,
 #' matching the format of the file in cdcepi/FluSight-forecast-hub/auxiliary-data/locations.csv
-#' @param season String naming the season: either "2022/23" or "2023/24"
+#' @param season String naming the season: only "2023/24" is supported
 #'
 #' @details Compute the bin endpoints used for categorical targets in FluSight
-#' in the 2022/23 or 2023/24 season. In both seasons, there were 5 categories:
+#' in the 2023/24 season. In that season, there were 5 categories:
 #' "large decrease", "decrease", "stable", "increase", and "large increase".
 #' The bin endpoints have the form
 #' value +/- max(multiplier * population / 100k, min_count),
@@ -56,36 +56,28 @@ get_flusight_bin_endpoints <- function(target_ts, location_meta, season) {
 }
 
 get_flusight_bin_endpoint_meta <- function(season) {
-  if (season == "2022/23") {
-    return(get_flusight_bin_endpoint_meta_2223())
-  } else if (season == "2023/24") {
+  if (season == "2023/24") {
     return(get_flusight_bin_endpoint_meta_2324())
   } else {
     stop("unsupported season")
   }
 }
 
-get_flusight_bin_endpoint_meta_2223 <- function() {
-  bin_endpoint_meta <- data.frame(
-    output_type_id = c("large_decrease", "decrease", "stable", "increase", "large_increase"),
-    lower_sign = c(-1, -1, -1, 1, 1),
-    lower_rate_multiplier = c(Inf, 2, 1, 1, 2),
-    lower_min_count_change = c(Inf, 40, 20, 20, 40),
-    upper_sign = c(-1, -1, 1, 1, 1),
-    upper_rate_multiplier = c(2, 1, 1, 2, Inf),
-    upper_min_count_change = c(40, 20, 20, 40, Inf),
-    stringsAsFactors = FALSE
-  )
-
-  return(bin_endpoint_meta)
-}
 
 get_flusight_bin_endpoint_meta_2324 <- function() {
+  # we use 9.5 for the minimum count change because, for low-population states,
+  # the intervals have the half-open form (value - 9.5, value + 9.5] for stable,
+  # (value - ***, value - 9.5] for decrease and large decrease, and
+  # (value + 9.5, value + ***] for increase and large increase,
+  # where *** comes from a population rate per 100k: all state populations
+  # divided by 100000 exceed 5, so any bound using a multiplier of 2 or more is above 9.5.
+  # In all cases, this says an integer count change of less than 10 is stable
+  # and a count change of 10 or more is non-stable.
   bin_endpoint_meta <- tidyr::expand_grid(
     horizon = 0:3,
     output_type_id = list(c("large_decrease", "decrease", "stable", "increase", "large_increase")),
-    lower_min_count_change = 10,
-    upper_min_count_change = 10
+    lower_min_count_change = 9.5,
+    upper_min_count_change = 9.5
   )
 
   bin_endpoint_meta$rate_multiplier_endpoints <- list(
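
A standalone numeric sketch of the 9.5 arithmetic described in the comment above; `value` and `pop100k` below are made-up numbers, not objects from the package:

    # half-open bins (lower, upper] for horizon 0, with made-up inputs
    value <- 100
    pop100k <- 5.79                             # roughly the smallest state pop / 100k
    stable_halfwidth <- max(1 * pop100k, 9.5)   # the min count 9.5 wins when 1 * pop100k < 9.5
    outer_halfwidth <- max(2 * pop100k, 9.5)    # 2 * pop100k = 11.58 always beats 9.5
    # the "stable" bin is (90.5, 109.5]: integer changes of up to 9 stay stable,
    # changes of 10 or more fall into the (large) increase/decrease bins
    c(value - stable_halfwidth, value + stable_halfwidth)
    #> [1]  90.5 109.5
    # the "increase" bin is (109.5, 111.58]
    c(value + stable_halfwidth, value + outer_halfwidth)
    #> [1] 109.50 111.58
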
diff --git a/tests/testthat/fixtures/location_meta_24.csv b/tests/testthat/fixtures/location_meta_24.csv
new file mode 100644
index 0000000..13ab868
--- /dev/null
+++ b/tests/testthat/fixtures/location_meta_24.csv
@@ -0,0 +1,54 @@
+abbreviation,location,location_name,population,...5,count_rate1,count_rate2,count_rate2p5,count_rate3,count_rate4,count_rate5
+US,US,US,332200066,NA,3322,6644,8305,9966,13288,16610
+AL,01,Alabama,5063778,NA,51,101,127,152,203,253
+AK,02,Alaska,711426,NA,7,14,18,21,28,36
+AZ,04,Arizona,7341018,NA,73,147,184,220,294,367
+AR,05,Arkansas,3041878,NA,30,61,76,91,122,152
+CA,06,California,38886551,NA,389,778,972,1167,1555,1944
+CO,08,Colorado,5803748,NA,58,116,145,174,232,290
+CT,09,Connecticut,3621089,NA,36,72,91,109,145,181
+DE,10,Delaware,1014872,NA,10,20,25,30,41,51
+DC,11,District of Columbia,668576,NA,7,13,17,20,27,33
+FL,12,Florida,22183852,NA,222,444,555,666,887,1109
+GA,13,Georgia,10855454,NA,109,217,271,326,434,543
+HI,15,Hawaii,1398977,NA,14,28,35,42,56,70
+ID,16,Idaho,1935521,NA,19,39,48,58,77,97
+IL,17,Illinois,12563096,NA,126,251,314,377,503,628
+IN,18,Indiana,6831941,NA,68,137,171,205,273,342
+IA,19,Iowa,3200224,NA,32,64,80,96,128,160
+KS,20,Kansas,2916451,NA,29,58,73,87,117,146
+KY,21,Kentucky,4494379,NA,45,90,112,135,180,225
+LA,22,Louisiana,4575074,NA,46,92,114,137,183,229
+ME,23,Maine,1384543,NA,14,28,35,42,55,69
+MD,24,Maryland,6133130,NA,61,123,153,184,245,307
+MA,25,Massachusetts,6978662,NA,70,140,174,209,279,349
+MI,26,Michigan,10032075,NA,100,201,251,301,401,502
+MN,27,Minnesota,5716548,NA,57,114,143,171,229,286
+MS,28,Mississippi,2927305,NA,29,59,73,88,117,146
+MO,29,Missouri,6164537,NA,62,123,154,185,247,308
+MT,30,Montana,1119563,NA,11,22,28,34,45,56
+NE,31,Nebraska,1961505,NA,20,39,49,59,78,98
+NV,32,Nevada,3165539,NA,32,63,79,95,127,158
+NH,33,New Hampshire,1394692,NA,14,28,35,42,56,70
+NJ,34,New Jersey,9254137,NA,93,185,231,278,370,463
+NM,35,New Mexico,2100079,NA,21,42,53,63,84,105
+NY,36,New York,19657190,NA,197,393,491,590,786,983
+NC,37,North Carolina,10596562,NA,106,212,265,318,424,530
+ND,38,North Dakota,772061,NA,8,15,19,23,31,39
+OH,39,Ohio,11749303,NA,117,235,294,352,470,587
+OK,40,Oklahoma,4001266,NA,40,80,100,120,160,200
+OR,41,Oregon,4238665,NA,42,85,106,127,170,212
+PA,42,Pennsylvania,12969276,NA,130,259,324,389,519,648
+RI,44,Rhode Island,1090390,NA,11,22,27,33,44,55
+SC,45,South Carolina,5246039,NA,52,105,131,157,210,262
+SD,46,South Dakota,906458,NA,9,18,23,27,36,45
+TN,47,Tennessee,7030607,NA,70,141,176,211,281,352
+TX,48,Texas,29914599,NA,299,598,748,897,1197,1496
+UT,49,Utah,3376238,NA,34,68,84,101,135,169
+VT,50,Vermont,646910,NA,6,13,16,19,26,32
+VA,51,Virginia,8583866,NA,86,172,215,258,343,429
+WA,53,Washington,7735834,NA,77,155,193,232,309,387
+WV,54,West Virginia,1774977,NA,18,35,44,53,71,89
+WI,55,Wisconsin,5891022,NA,59,118,147,177,236,295
+WY,56,Wyoming,578583,NA,6,12,14,17,23,29
+PR,72,Puerto Rico,3221789,NA,32,64,81,97,129,161
diff --git a/tests/testthat/fixtures/setup_flusight_fixtures.R b/tests/testthat/fixtures/setup_flusight_fixtures.R
new file mode 100644
index 0000000..2ad080d
--- /dev/null
+++ b/tests/testthat/fixtures/setup_flusight_fixtures.R
@@ -0,0 +1,7 @@
+# location metadata, 2022-23
+# location_meta_23 <- read.csv("https://raw.githubusercontent.com/cdcepi/Flusight-forecast-data/master/data-locations/locations.csv")
+# readr::write_csv(location_meta_23, "tests/testthat/fixtures/location_meta_23.csv")
+
+# location metadata, 2023-24
+location_meta_24 <- readr::read_csv("https://raw.githubusercontent.com/cdcepi/FluSight-forecast-hub/refs/heads/main/auxiliary-data/locations.csv")
+readr::write_csv(location_meta_24, "tests/testthat/fixtures/location_meta_24.csv")
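
The count_rate columns in this fixture appear to follow count_rateX == round(X * population / 100000); that relationship is an inference from spot-checking rows, not something documented in the hub. A standalone sketch of the check, assuming the fixture path above:

    # spot-check the assumed relationship between count_rate columns and population
    location_meta <- readr::read_csv("tests/testthat/fixtures/location_meta_24.csv")
    all(location_meta$count_rate1 == round(1 * location_meta$population / 1e5))
    all(location_meta$count_rate2p5 == round(2.5 * location_meta$population / 1e5))
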
diff --git a/tests/testthat/test-get_flusight_bin_endpoints.R b/tests/testthat/test-get_flusight_bin_endpoints.R
new file mode 100644
index 0000000..3a25783
--- /dev/null
+++ b/tests/testthat/test-get_flusight_bin_endpoints.R
@@ -0,0 +1,218 @@
+test_that("get_flusight_bin_endpoints works, 2023/24", {
+  # definitions at
+  # https://github.com/cdcepi/FluSight-forecast-hub/tree/main/model-output#rate-trend-forecast-specifications
+  #
+  # our strategy is to:
+  # - construct data that should fall into known categories
+  #   (with every horizon/category combination, and both criteria for stable)
+  # - compute bins and apply them to the data
+  # - check that we got the right answers
+  location_meta <- readr::read_csv(
+    file = testthat::test_path("fixtures", "location_meta_24.csv")
+  ) |>
+    dplyr::mutate(
+      pop100k = .data[["population"]] / 100000
+    )
+
+  # locations for testing the "stable", "increase", and "decrease" thresholds:
+  # 56 = Wyoming, pop100k = 5.78;
+  # 02 = Alaska, pop100k = 7.11; and 11 = District of Columbia, pop100k = 6.69.
+  # these states trigger the "minimum count change at least 10" rule
+  locs <- c("US", "56", "02", "11", "05", "06")
+
+  # create data
+  # our reference date will be 2023-10-21.
+  # changes are relative to 2023-10-14.
+  # 2023-10-07 is a throw-away, to make sure we grab the right "relative to" date
+  target_data <- tidyr::expand_grid(
+    location = locs,
+    date = as.Date("2023-10-07") + seq(from = 0, by = 7, length.out = 6),
+    value = NA_real_
+  )
+
+  expected_categories <- NULL
+
+  # all expected category levels are "stable": rate change less than
+  # 1, 1, 2, or 2.5 times the population rate
+  loc_pop100k <- location_meta$pop100k[location_meta$location == locs[1]]
+  target_data$value[target_data$location == locs[1]] <- c(
+    0,
+    10000,
+    10000 + floor(0.99 * loc_pop100k),
+    10000 - floor(0.99 * loc_pop100k),
+    10000 + floor(1.99 * loc_pop100k),
+    10000 - floor(2.49 * loc_pop100k)
+  )
+  expected_categories <- dplyr::bind_rows(
+    expected_categories,
+    target_data |>
+      dplyr::filter(
+        .data[["location"]] == locs[1],
+        .data[["date"]] >= "2023-10-21"
+      ) |>
+      dplyr::mutate(
+        horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7),
+        output_type_id = "stable"
+      )
+  )
+
+  # all expected category levels are "stable": count change less than 10
+  target_data$value[target_data$location == locs[2]] <- c(0, 300, 304, 297, 309, 291)
+  expected_categories <- dplyr::bind_rows(
+    expected_categories,
+    target_data |>
+      dplyr::filter(
+        .data[["location"]] == locs[2],
+        .data[["date"]] >= "2023-10-21"
+      ) |>
+      dplyr::mutate(
+        horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7),
+        output_type_id = "stable"
+      )
+  )
+
+  # all expected category levels are "increase": count change >= 10,
+  # horizon 0: 1 <= rate change < 2
+  # horizon 1: 1 <= rate change < 3
+  # horizon 2: 2 <= rate change < 4
+  # horizon 3: 2.5 <= rate change < 5
+  # note, loc_pop100k for this location is 7.11 < 10
+  loc_pop100k <- location_meta$pop100k[location_meta$location == locs[3]]
+  target_data$value[target_data$location == locs[3]] <- c(
+    0,
+    10000,
+    10000 + 10,
+    10000 + floor(2.99 * loc_pop100k),
+    10000 + floor(3.99 * loc_pop100k),
+    10000 + ceiling(2.51 * loc_pop100k)
+  )
+  expected_categories <- dplyr::bind_rows(
+    expected_categories,
+    target_data |>
+      dplyr::filter(
+        .data[["location"]] == locs[3],
+        .data[["date"]] >= "2023-10-21"
+      ) |>
+      dplyr::mutate(
+        horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7),
+        output_type_id = "increase"
+      )
+  )
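
To make the expected-category bookkeeping above concrete, here is a standalone sketch of the half-open membership rule the test applies at the end of this file; the endpoint numbers are invented for illustration, not taken from the function:

    # classify one value against hypothetical (lower, upper] bin endpoints
    bins <- data.frame(
      output_type_id = c("large_decrease", "decrease", "stable", "increase", "large_increase"),
      lower = c(-Inf, 88.42, 90.5, 109.5, 111.58),
      upper = c(88.42, 90.5, 109.5, 111.58, Inf)
    )
    value <- 110
    bins$output_type_id[bins$lower < value & value <= bins$upper]
    #> [1] "increase"
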
+
+  # all expected category levels are "decrease": count change <= -10,
+  # horizon 0: -1 >= rate change > -2
+  # horizon 1: -1 >= rate change > -3
+  # horizon 2: -2 >= rate change > -4
+  # horizon 3: -2.5 >= rate change > -5
+  # note, loc_pop100k for this location is 6.69 < 10
+  loc_pop100k <- location_meta$pop100k[location_meta$location == locs[4]]
+  target_data$value[target_data$location == locs[4]] <- c(
+    0,
+    10000,
+    10000 - 10,
+    10000 - floor(2.99 * loc_pop100k),
+    10000 - floor(3.99 * loc_pop100k),
+    10000 - ceiling(2.51 * loc_pop100k)
+  )
+  expected_categories <- dplyr::bind_rows(
+    expected_categories,
+    target_data |>
+      dplyr::filter(
+        .data[["location"]] == locs[4],
+        .data[["date"]] >= "2023-10-21"
+      ) |>
+      dplyr::mutate(
+        horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7),
+        output_type_id = "decrease"
+      )
+  )
+
+  # all expected category levels are "large increase": count change >= 10,
+  # horizon 0: 2 <= rate change
+  # horizon 1: 3 <= rate change
+  # horizon 2: 4 <= rate change
+  # horizon 3: 5 <= rate change
+  loc_pop100k <- location_meta$pop100k[location_meta$location == locs[5]]
+  target_data$value[target_data$location == locs[5]] <- c(
+    0,
+    10000,
+    10000 + max(10, ceiling(2 * loc_pop100k)),
+    10000 + max(10, ceiling(3 * loc_pop100k)),
+    10000 + max(10, ceiling(4 * loc_pop100k)),
+    10000 + max(10, ceiling(5 * loc_pop100k))
+  )
+
+  expected_categories <- dplyr::bind_rows(
+    expected_categories,
+    target_data |>
+      dplyr::filter(
+        .data[["location"]] == locs[5],
+        .data[["date"]] >= "2023-10-21"
+      ) |>
+      dplyr::mutate(
+        horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7),
+        output_type_id = "large_increase"
+      )
+  )
+
+  # all expected category levels are "large decrease": count change <= -10,
+  # horizon 0: -2 >= rate change
+  # horizon 1: -3 >= rate change
+  # horizon 2: -4 >= rate change
+  # horizon 3: -5 >= rate change
+  loc_pop100k <- location_meta$pop100k[location_meta$location == locs[6]]
+  target_data$value[target_data$location == locs[6]] <- c(
+    0,
+    10000,
+    10000 - max(10, ceiling(2 * loc_pop100k)),
+    10000 - max(10, ceiling(3 * loc_pop100k)),
+    10000 - max(10, ceiling(4 * loc_pop100k)),
+    10000 - max(10, ceiling(5 * loc_pop100k))
+  )
+
+  expected_categories <- dplyr::bind_rows(
+    expected_categories,
+    target_data |>
+      dplyr::filter(
+        .data[["location"]] == locs[6],
+        .data[["date"]] >= "2023-10-21"
+      ) |>
+      dplyr::mutate(
+        horizon = as.integer((.data[["date"]] - as.Date("2023-10-21")) / 7),
+        output_type_id = "large_decrease"
+      )
+  )
+
+  bin_endpoints <- get_flusight_bin_endpoints(
+    target_ts = target_data |>
+      dplyr::filter(
+        .data[["date"]] < "2023-10-21"
+      ),
+    location_meta = location_meta,
+    season = "2023/24"
+  )
+
+  actual_categories <- bin_endpoints |>
+    dplyr::mutate(
+      reference_date = as.Date("2023-10-21"),
+      target_end_date = as.Date("2023-10-21") + 7 * .data[["horizon"]]
+    ) |>
+    dplyr::left_join(
+      target_data,
+      by = c("location", "target_end_date" = "date")
+    ) |>
+    dplyr::filter(
+      .data[["lower"]] < .data[["value"]],
+      .data[["value"]] <= .data[["upper"]]
+    )
+
+  # expected_categories stores the target end date in its `date` column,
+  # so join on that (the reference date is constant) plus location and horizon
+  mismatched_categorizations <- expected_categories |>
+    dplyr::left_join(
+      actual_categories,
+      by = c("location", "date" = "target_end_date", "horizon")
+    ) |>
+    dplyr::filter(output_type_id.x != output_type_id.y)
+
+  # expect no mismatches!
+  expect_equal(nrow(mismatched_categorizations), 0L)
+})
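
Based on how the test consumes the return value, get_flusight_bin_endpoints() appears to return one row per location/horizon/category with `lower` and `upper` columns; a hedged usage sketch follows. The inputs are invented, the fixture path is reused from above, the return shape is inferred from the test rather than documented, and two weeks of history is assumed to be sufficient input:

    # minimal call: two weeks of observed data for one location
    target_ts <- data.frame(
      location = "US",
      date = as.Date(c("2023-10-07", "2023-10-14")),
      value = c(9800, 10000)
    )
    location_meta <- readr::read_csv("tests/testthat/fixtures/location_meta_24.csv")
    bins <- get_flusight_bin_endpoints(target_ts, location_meta, season = "2023/24")
    # inferred columns: location, horizon, output_type_id, lower, upper
    head(bins)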