Skip to content

Commit

Permalink
Merge pull request #206 from HopkinsIDD/inference_with_usa
Browse files Browse the repository at this point in the history
Inference with usa
  • Loading branch information
jcblemai authored Jun 6, 2024
2 parents 1dab9a9 + 6e302b7 commit ad733ca
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 105 deletions.
1 change: 0 additions & 1 deletion build/conda_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,6 @@ dependencies:
- conda-forge/linux-64::r-readxl==1.4.1=r42h3ebcfa7_1
- conda-forge/noarch::r-reprex==2.0.2=r42hc72bb7e_1
- conda-forge/linux-64::r-tidyr==1.3.0=r42h38f115c_0
- conda-forge/noarch::r-tigris==2.0.1=r42hc72bb7e_0
- conda-forge/noarch::r-waldo==0.4.0=r42hc72bb7e_1
- conda-forge/noarch::r-broom==1.0.3=r42hc72bb7e_0
- conda-forge/linux-64::r-gdtools==0.3.0=r42he0ce631_0
Expand Down
2 changes: 1 addition & 1 deletion build/local_install.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ local({r <- getOption("repos")

library(devtools)

install.packages(c("covidcast","data.table","vroom","dplyr"), quiet=TRUE, dependencies = TRUE)
install.packages(c("covidcast","data.table","vroom","dplyr"), quiet=TRUE)
# devtools::install_github("hrbrmstr/cdcfluview")

# To run if operating in the container -----
Expand Down
1 change: 0 additions & 1 deletion datasetup/build_US_setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ state_level <- ifelse(!is.null(config$subpop_setup$state_level) && config$subpop
# tidycensus::census_api_key(key = census_key)



filterUSPS <- c("WY","VT","DC","AK","ND","SD","DE","MT","RI","ME","NH","HI","ID","WV","NE","NM",
"KS","NV","MS","AR","UT","IA","CT","OK","OR","KY","LA","AL","SC","MN","CO","WI",
"MD","MO","IN","TN","MA","AZ","WA","VA","NJ","MI","NC","GA","OH","IL","PA","NY","FL","TX","CA")
Expand Down
1 change: 0 additions & 1 deletion flepimop/R_packages/flepicommon/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ export(fix_negative_counts_single_subpop)
export(get_CSSE_US_data)
export(get_CSSE_US_matchGlobal_data)
export(get_CSSE_global_data)
export(get_USAFacts_data)
export(get_covidcast_data)
export(get_groundtruth_from_source)
export(load_config)
Expand Down
58 changes: 0 additions & 58 deletions flepimop/R_packages/flepicommon/R/DataUtils.R
Original file line number Diff line number Diff line change
Expand Up @@ -426,64 +426,6 @@ aggregate_counties_to_state <- function(df, state_fips){
}


##'
##' Pull case and death count data from USAFacts
##'
##' Fetches the USAFacts cumulative confirmed-case and death series, merges them
##' into one table, appends the island-area rows supplied by
##' `get_islandareas_data()` (USAFacts does not include data for all the
##' territories; these are pulled from NYTimes), and derives incident counts
##' per FIPS as first differences of the cumulative series.
##'
##' Returned data preview:
##' tibble [352,466 × 7] (S3: grouped_df/tbl_df/tbl/data.frame)
##' $ FIPS : chr [1:352466] "00001" "00001" "00001" "00001" ...
##' $ source : chr [1:352466] "NY" "NY" "NY" "NY" ...
##' $ Update : Date[1:352466], format: "2020-01-22" "2020-01-23" ...
##' $ Confirmed : num [1:352466] 0 0 0 0 0 0 0 0 0 0 ...
##' $ Deaths : num [1:352466] 0 0 0 0 0 0 0 0 0 0 ...
##' $ incidI : num [1:352466] 0 0 0 0 0 0 0 0 0 0 ...
##' $ incidDeath : num [1:352466] 0 0 0 0 0 0 0 0 0 0 ...
##'
##' @param case_data_filename Filename where case data are stored
##' @param death_data_filename Filename where death data are stored
##' @param incl_unassigned Includes data unassigned to counties (default is FALSE)
##' @return the case and deaths data frame, grouped by FIPS, with the added
##'   columns `incidI` and `incidDeath`
##'
##'
##' @export
##'
get_USAFacts_data <- function(case_data_filename = "data/case_data/USAFacts_case_data.csv",
                              death_data_filename = "data/case_data/USAFacts_death_data.csv",
                              incl_unassigned = FALSE){

  USAFACTS_CASE_DATA_URL <- "https://usafactsstatic.blob.core.windows.net/public/data/covid-19/covid_confirmed_usafacts.csv"
  USAFACTS_DEATH_DATA_URL <- "https://usafactsstatic.blob.core.windows.net/public/data/covid-19/covid_deaths_usafacts.csv"

  # One cumulative series per outcome; the third argument names the value column.
  cases <- download_USAFacts_data(case_data_filename, USAFACTS_CASE_DATA_URL, "Confirmed", incl_unassigned)
  deaths <- download_USAFacts_data(death_data_filename, USAFACTS_DEATH_DATA_URL, "Deaths", incl_unassigned)

  # Merge on the columns both tables share (dplyr's natural join), then keep
  # only the fields that make up the returned schema.
  combined <- dplyr::select(
    dplyr::full_join(cases, deaths),
    Update, source, FIPS, Confirmed, Deaths
  )

  # Append island areas, then order rows so each FIPS series is chronological
  # within its source before differencing.
  combined <- dplyr::arrange(
    rbind(combined, get_islandareas_data()),
    source, FIPS, Update
  )

  # Incident counts for a single FIPS: the first observation carries its own
  # cumulative value, every later row is the day-over-day change.
  add_incident_columns <- function(.x, .y){
    .x$incidI <- c(.x$Confirmed[1], diff(.x$Confirmed))
    .x$incidDeath <- c(.x$Deaths[1], diff(.x$Deaths))
    return(.x)
  }
  by_fips <- dplyr::group_modify(dplyr::group_by(combined, FIPS), add_incident_columns)

  # Fix incidence counts that go negative and NA values or missing dates
  by_fips <- fix_negative_counts(by_fips, "Confirmed", "incidI")
  by_fips <- fix_negative_counts(by_fips, "Deaths", "incidDeath")

  return(by_fips)
}



##'
Expand Down
30 changes: 30 additions & 0 deletions flepimop/R_packages/flepiconfig/R/process_npi_list.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,36 @@ NULL



##' load_geodata_file
##'
##' Reads a geodata CSV, checks that it has a `geoid` column, coerces that
##' column to character, optionally pads it to a fixed width and optionally
##' attaches state name / USPS abbreviation columns taken from the
##' `fips_us_county` dataset shipped with `flepicommon`.
##'
##' @param filename Path to the geodata CSV file. Must exist and contain a
##'   column named `geoid`.
##' @param geoid_len Target width of `geoid`. When greater than 0 the ids are
##'   padded with `geoid_pad` (on the left, the `stringr::str_pad` default) up
##'   to this width; 0 (the default) leaves them untouched.
##' @param geoid_pad Single character used for padding (default "0").
##' @param state_name If TRUE (default), join in `USPS` (abbreviation) and
##'   `state` (full name, with "U.S. Virgin Islands" recoded to
##'   "Virgin Islands"). The join uses whichever columns the lookup and the
##'   geodata share (dplyr's default natural join) and keeps every geodata row.
##' @return The geodata as a data frame with a character `geoid` column.
##'
load_geodata_file <- function(filename,
                              geoid_len = 0,
                              geoid_pad = "0",
                              state_name = TRUE) {

  if (!file.exists(filename)) {
    stop(paste(filename, "does not exist in", getwd()))
  }

  geodata <- readr::read_csv(filename)

  # Validate BEFORE touching the column: coercing a non-existent `geoid` would
  # abort inside dplyr with an opaque "object 'geoid' not found" error and the
  # explicit message below would never be reached.
  if (!("geoid" %in% names(geodata))) {
    stop(paste(filename, "does not have a column named geoid"))
  }

  # Ids are identifiers, not numbers; keep them as strings so padding works.
  geodata <- dplyr::mutate(geodata, geoid = as.character(geoid))

  if (geoid_len > 0) {
    geodata$geoid <- stringr::str_pad(geodata$geoid, geoid_len, pad = geoid_pad)
  }

  if (state_name) {
    # Loads `fips_us_county` into this function's environment.
    utils::data(fips_us_county, package = "flepicommon") # arrow::read_parquet("datasetup/usdata/fips_us_county.parquet")
    geodata <- fips_us_county %>%
      dplyr::distinct(state, state_name) %>%
      dplyr::rename(USPS = state) %>%
      dplyr::rename(state = state_name) %>%
      dplyr::mutate(state = dplyr::recode(state, "U.S. Virgin Islands" = "Virgin Islands")) %>%
      # right_join keeps every geodata row, matched or not.
      dplyr::right_join(geodata)
  }

  return(geodata)
}

##' find_truncnorm_mean_parameter
##'
##' Convenience function that estimates the mean value for a truncnorm distribution given a, b, and sd that will have the expected value of the input mean.
Expand Down
1 change: 1 addition & 0 deletions flepimop/main_scripts/create_seeding_added.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ library(purrr)

option_list <- list(
optparse::make_option(c("-c", "--config"), action = "store", default = Sys.getenv("CONFIG_PATH"), type = "character", help = "path to the config file"),
optparse::make_option(c("-p", "--flepi_path"), action="store", type='character', help="path to the flepiMoP directory", default = Sys.getenv("FLEPI_PATH", "flepiMoP/")),
optparse::make_option(c("-k", "--keep_all_seeding"), action="store",default=TRUE,type='logical',help="Whether to filter away seeding prior to the start date of the simulation.")
)
opt <- optparse::parse_args(optparse::OptionParser(option_list = option_list))
Expand Down
2 changes: 1 addition & 1 deletion flepimop/main_scripts/inference_slot.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ if (opt$config == ""){
}
config = flepicommon::load_config(opt$config)


if (!is.null(config$inference$incl_aggr_likelihood)){
print("Using config option for `incl_aggr_likelihood`.")
opt$incl_aggr_likelihood <- config$inference$incl_aggr_likelihood
Expand Down Expand Up @@ -275,6 +274,7 @@ if (config$inference$do_inference){
dplyr::right_join(tidyr::expand_grid(subpop = unique(.$subpop), date = unique(.$date))) %>%
dplyr::mutate_if(is.numeric, dplyr::coalesce, 0)


# add aggregate groundtruth to the obs data for the likelihood calc
if (opt$incl_aggr_likelihood){
obs <- obs %>%
Expand Down
84 changes: 42 additions & 42 deletions postprocessing/plot_predictions.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ center_line_var <- ifelse(point_est==0.5, "point", "point-mean")
proj_data <- data_comb


#### Which valid locations are missing from our submission?
#### Which valid locations are missing from our submission?

# locs <- read_csv("https://raw.githubusercontent.com/reichlab/covid19-forecast-hub/master/data-locations/locations.csv")
# mismatched <- unique(proj_data$location)[which(!(unique(proj_data$location) %in% locs$location))]
# missing_from_fc <- unique(locs$location)[which(!(locs$location %in% unique(proj_data$location)))]
#
# locs %>% filter(location %in% missing_from_fc)
#
# locs %>% filter(location %in% missing_from_fc)


# STATE DATA --------------------------------------------------------------
Expand All @@ -38,7 +38,7 @@ state_cw <- fips_us_county %>%

# GROUND TRUTH ------------------------------------------------------------

gt_data <- gt_data %>%
gt_data <- gt_data %>%
mutate(time = lubridate::as_date(time)) %>% mutate(date = time)
colnames(gt_data) <- gsub("incidI", "incidC", colnames(gt_data))
gt_outcomes <- outcomes_[outcomes_ != "I" & sapply(X = paste0("incid", outcomes_), FUN = function(x=X, y) any(grepl(pattern = x, x = y)), y = colnames(gt_data)) ]
Expand Down Expand Up @@ -67,16 +67,16 @@ if (any(outcomes_time_=="weekly")) {
mutate(agestrat="age0to130") %>%
rename(outcome = outcome_name, value = outcome) %>%
filter(outcome %in% paste0("incid", weekly_cum_outcomes_)),
obs_data = gt_data_2,
obs_data = gt_data_2,
gt_cum_vars = paste0("cum", outcomes_gt_[outcomes_cumfromgt_gt_]), # variables to get cum from GT
forecast_date = lubridate::as_date(forecast_date),
aggregation="week",
loc_column = "USPS",
loc_column = "USPS",
use_obs_data = use_obs_data_forcum) %>%
rename(outcome_name = outcome, outcome = value) %>%
select(-agestrat)
gt_data_st_week <- gt_data_st_week %>%

gt_data_st_week <- gt_data_st_week %>%
bind_rows(gt_data_st_weekcum)
}
gt_cl <- gt_cl %>% bind_rows(gt_data_st_week %>% mutate(time_aggr = "weekly"))
Expand All @@ -93,11 +93,11 @@ if (any(outcomes_time_=="daily")) {
mutate(agestrat="age0to130") %>%
rename(outcome = outcome_name, value = outcome) %>%
filter(outcome %in% paste0("incid", daily_cum_outcomes_)),
obs_data = gt_data_2,
obs_data = gt_data_2,
gt_cum_vars = paste0("cum", outcomes_gt_[outcomes_cumfromgt_gt_]), # variables to get cum from GT
forecast_date = lubridate::as_date(forecast_date),
aggregation="day",
loc_column = "USPS",
loc_column = "USPS",
use_obs_data = use_obs_data_forcum) %>%
rename(outcome_name = outcome, outcome = value) %>%
select(-agestrat)
Expand All @@ -117,12 +117,12 @@ gt_cl <- gt_cl %>% rename(date = time)
# inc_dat_st_vars <- inc_dat_st_vars %>% filter(date != max(date))
# }

dat_st_cl2 <- gt_cl %>%
dat_st_cl2 <- gt_cl %>%
select(date, USPS, target = outcome_name, time_aggr, value = outcome) %>%
mutate(incid_cum = ifelse(grepl("inc", target), "inc", "cum")) %>%
mutate(aggr_target = !grepl('_', target)) %>%
mutate(outcome = substr(gsub("cum|incid", "", target), 1,1)) %>%
mutate(pre_gt_end = date<=validation_date)
mutate(pre_gt_end = date<=validation_date)



Expand All @@ -134,7 +134,7 @@ dat_st_cl2 <- gt_cl %>%

forecast_st <- proj_data %>%
filter(nchar(location)==2 & (quantile %in% sort(unique(c(quant_values, 0.5))) | is.na(quantile))) %>%
left_join(state_cw, by = c("location"))
left_join(state_cw, by = c("location"))

# filter out incid or cum
if (!plot_incid) { forecast_st <- forecast_st %>% filter(!grepl(" inc ", target)) }
Expand All @@ -150,12 +150,12 @@ if(any(outcomes_cum_)){
forecast_st <- forecast_st %>% filter(grepl(paste0(c(paste0("inc ", outcomes_name), cum_outcomes_name), collapse = "|"), target))

# create cat variables
forecast_st_plt <- forecast_st %>%
forecast_st_plt <- forecast_st %>%
mutate(incid_cum = ifelse(grepl("inc ", target), "inc", "cum")) %>%
mutate(outcome = stringr::word(target, 5)) %>%
mutate(outcome = recode(outcome, "inf"="I", "case"="C", "hosp"="H", "death"="D")) %>%
dplyr::mutate(quantile_cln = ifelse(!is.na(quantile), paste0("q", paste0(as.character(quantile*100), "%")),
ifelse(type=="point-mean", paste0("mean"),
dplyr::mutate(quantile_cln = ifelse(!is.na(quantile), paste0("q", paste0(as.character(quantile*100), "%")),
ifelse(type=="point-mean", paste0("mean"),
ifelse(type=="point", paste0("median"), NA)))) %>%
mutate(target_type = paste0(incid_cum, outcome))

Expand All @@ -173,12 +173,12 @@ if(center_line == "mean"){
forecast_st_plt <- forecast_st_plt %>% mutate(quantile_cln = gsub("q50%", "ctr", quantile_cln))
}

forecast_st_plt <- forecast_st_plt %>%
forecast_st_plt <- forecast_st_plt %>%
select(scenario_name, scenario_id, target = target_type, incid_cum, outcome, date = target_end_date, USPS, quantile_cln, value) %>%
pivot_wider(names_from = quantile_cln, values_from = value) %>%
mutate(type = "projection") %>%
full_join(pltdat_truth %>%
mutate(type = "gt", scenario_name = ifelse(pre_gt_end, "gt-pre-projection", "gt-post-projection")) %>%
full_join(pltdat_truth %>%
mutate(type = "gt", scenario_name = ifelse(pre_gt_end, "gt-pre-projection", "gt-post-projection")) %>%
select(date, USPS, target = target_type, incid_cum, type, scenario_name, ctr=gt)) %>%
filter(date >= trunc_date & date <= sim_end_date)

Expand All @@ -201,15 +201,15 @@ stplot_fname_nosqrt <- paste0(stplot_fname, ".pdf")

pdf(stplot_fname_nosqrt, width=7, height=11)
for(usps in unique(forecast_st_plt$USPS)){

print(paste0("Plotting: ", usps))
cols_tmp <- cols[names(cols) %in% unique(forecast_st_plt$scenario_name)]

target_labs <- paste0(str_to_title(outcomes_time_[match(gsub("inc","",unique(forecast_st_plt$target)),outcomes_)]), " incident ", gsub("inc","",unique(forecast_st_plt$target)))
names(target_labs) <- unique(forecast_st_plt$target)
inc_st_plt <- forecast_st_plt %>%
filter(USPS == usps) %>%

inc_st_plt <- forecast_st_plt %>%
filter(USPS == usps) %>%
filter(incid_cum=="inc") %>%
mutate(scenario_name = factor(scenario_name)) %>%
ggplot(aes(x = date)) +
Expand All @@ -232,15 +232,15 @@ for(usps in unique(forecast_st_plt$USPS)){
theme(legend.position = "bottom", legend.text = element_text(size=10),
axis.text.x = element_text(size=6, angle = 45))
plot(inc_st_plt)


if (plot_cum) {

target_labs <- paste0(str_to_title(outcomes_time_[match(gsub("cum","",unique(forecast_st_plt$target)),outcomes_)]), " cumulative ", gsub("cum","",unique(forecast_st_plt$target)))
names(target_labs) <- unique(forecast_st_plt$target)
cum_st_plt <- forecast_st_plt %>%
filter(USPS == usps) %>%

cum_st_plt <- forecast_st_plt %>%
filter(USPS == usps) %>%
filter(incid_cum=="cum") %>%
mutate(scenario_name = factor(scenario_name)) %>%
ggplot(aes(x = date)) +
Expand All @@ -261,7 +261,7 @@ for(usps in unique(forecast_st_plt$USPS)){
labeller = as_labeller(target_labs)) +
theme(legend.position = "bottom", legend.text = element_text(size=10),
axis.text.x = element_text(size=6, angle = 45))

plot(cum_st_plt)
}
}
Expand All @@ -273,15 +273,15 @@ scale_y_funct <- scale_y_sqrt

pdf(stplot_fname_sqrt, width=7, height=11)
for(usps in unique(forecast_st_plt$USPS)){

print(paste0("Plotting: ", usps))
cols_tmp <- cols[names(cols) %in% unique(forecast_st_plt$scenario_name)]

target_labs <- paste0(str_to_title(outcomes_time_[match(gsub("inc","",unique(forecast_st_plt$target)),outcomes_)]), " incident ", gsub("inc","",unique(forecast_st_plt$target)))
names(target_labs) <- unique(forecast_st_plt$target)
inc_st_plt <- forecast_st_plt %>%
filter(USPS == usps) %>%

inc_st_plt <- forecast_st_plt %>%
filter(USPS == usps) %>%
filter(incid_cum=="inc") %>%
mutate(scenario_name = factor(scenario_name)) %>%
ggplot(aes(x = date)) +
Expand All @@ -303,14 +303,14 @@ for(usps in unique(forecast_st_plt$USPS)){
theme(legend.position = "bottom", legend.text = element_text(size=10),
axis.text.x = element_text(size=6, angle = 45))
plot(inc_st_plt)

if (plot_cum) {

target_labs <- paste0(str_to_title(outcomes_time_[match(gsub("cum","",unique(forecast_st_plt$target)),outcomes_)]), " cumulative ", gsub("cum","",unique(forecast_st_plt$target)))
names(target_labs) <- unique(forecast_st_plt$target)
cum_st_plt <- forecast_st_plt %>%
filter(USPS == usps) %>%

cum_st_plt <- forecast_st_plt %>%
filter(USPS == usps) %>%
filter(incid_cum=="cum") %>%
mutate(scenario_name = factor(scenario_name)) %>%
ggplot(aes(x = date)) +
Expand All @@ -331,7 +331,7 @@ for(usps in unique(forecast_st_plt$USPS)){
labeller = as_labeller(target_labs)) +
theme(legend.position = "bottom", legend.text = element_text(size=10),
axis.text.x = element_text(size=6, angle = 45))

plot(cum_st_plt)
}
}
Expand Down

0 comments on commit ad733ca

Please sign in to comment.