Skip to content

Commit

Permalink
Merge pull request #292 from HopkinsIDD/r-inference-fixes
Browse files Browse the repository at this point in the history
Fixing issues with inference and NAs
  • Loading branch information
jcblemai authored Sep 13, 2024
2 parents 3fe8d8f + e3ee55f commit fb5a9c5
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 61 deletions.
6 changes: 4 additions & 2 deletions flepimop/R_packages/inference/R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ periodAggregate <- function(data, dates, start_date = NULL, end_date = NULL, per
tmp <- tmp %>%
tidyr::unite("time_unit", names(tmp)[grepl("time_unit_", names(tmp))]) %>%
dplyr::group_by(time_unit) %>%
dplyr::summarize(first_date = min(date), value = aggregator(value), valid = period_unit_validator(date,time_unit)) %>%
dplyr::summarize(last_date = max(date),first_date = min(date), value = aggregator(value), valid = period_unit_validator(date,time_unit)) %>%
dplyr::ungroup() %>%
dplyr::arrange(first_date) %>%
dplyr::filter(valid)
return(matrix(tmp$value, ncol = 1, dimnames = list(as.character(tmp$first_date))))
# return(matrix(tmp$value, ncol = 1, dimnames = list(as.character(tmp$first_date))))
return(matrix(tmp$value, ncol = 1, dimnames = list(as.character(tmp$last_date))))

}


Expand Down
121 changes: 66 additions & 55 deletions flepimop/R_packages/inference/R/inference_slot_runner_funcs.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,18 @@ aggregate_and_calc_loc_likelihoods <- function(
this_location_log_likelihood <- 0
for (var in names(ground_truth_data[[location]])) {


obs_tmp1 <- ground_truth_data[[location]][[var]]
obs_tmp <- obs_tmp1[!is.na(obs_tmp1$data_var) & !is.na(obs_tmp1$date),]
sim_tmp1 <- this_location_modeled_outcome[[var]]
sim_tmp <- sim_tmp1[match(lubridate::as_date(sim_tmp1$date),
lubridate::as_date(obs_tmp$date)),] %>% na.omit()


this_location_log_likelihood <- this_location_log_likelihood +
## Actually compute likelihood for this location and statistic here:
sum(inference::logLikStat(
obs = ground_truth_data[[location]][[var]]$data_var,
sim = this_location_modeled_outcome[[var]]$sim_var,
obs = as.numeric(obs_tmp$data_var),
sim = as.numeric(sim_tmp$sim_var),
dist = targets_config[[var]]$likelihood$dist,
param = targets_config[[var]]$likelihood$param,
add_one = targets_config[[var]]$add_one
Expand Down Expand Up @@ -615,72 +621,77 @@ initialize_mcmc_first_block <- function(
## initial conditions (init)

if (!is.null(config$initial_conditions)){
if(config$initial_conditions$method != "plugin"){

if ("init_filename" %in% global_file_names) {

if (config$initial_conditions$method %in% c("FromFile", "SetInitialConditions")){

if (is.null(config$initial_conditions$initial_conditions_file)) {
stop("ERROR: Initial conditions file needs to be specified in the config under `initial_conditions:initial_conditions_file`")
}
initial_init_file <- config$initial_conditions$initial_conditions_file

} else if (config$initial_conditions$method %in% c("InitialConditionsFolderDraw", "SetInitialConditionsFolderDraw", "plugin")) {
print("Initial conditions in inference has not been fully implemented yet for the 'folder draw' methods,
and no copying to global or chimeric files is being done.")

if (is.null(config$initial_conditions$initial_file_type)) {
stop("ERROR: Initial conditions file needs to be specified in the config under `initial_conditions:initial_conditions_file`")
}
initial_init_file <- global_files[[paste0(config$initial_conditions$initial_file_type, "_filename")]]
}


if (!file.exists(initial_init_file)) {
stop("ERROR: Initial conditions file specified but does not exist.")
}

if (config$initial_conditions$method %in% c("FromFile", "SetInitialConditions")){

if (grepl(".csv", initial_init_file)){
initial_init <- readr::read_csv(initial_init_file,show_col_types = FALSE)
}else{
initial_init <- arrow::read_parquet(initial_init_file)
if (is.null(config$initial_conditions$initial_conditions_file)) {
stop("ERROR: Initial conditions file needs to be specified in the config under `initial_conditions:initial_conditions_file`")
}
initial_init_file <- config$initial_conditions$initial_conditions_file

# if the initial conditions file contains a 'date' column, filter for config$start_date
} else if (config$initial_conditions$method %in% c("InitialConditionsFolderDraw", "SetInitialConditionsFolderDraw")) {
print("Initial conditions in inference has not been fully implemented yet for the 'folder draw' methods,
and no copying to global or chimeric files is being done.")

if("date" %in% colnames(initial_init)){

initial_init <- initial_init %>%
dplyr::mutate(date = as.POSIXct(date, tz="UTC")) %>%
dplyr::filter(date == as.POSIXct(paste0(config$start_date, " 00:00:00"), tz="UTC"))

if (nrow(initial_init) == 0) {
stop("ERROR: Initial conditions file specified but does not contain the start date.")
}

if (is.null(config$initial_conditions$initial_file_type)) {
stop("ERROR: Initial conditions file needs to be specified in the config under `initial_conditions:initial_conditions_file`")
}

arrow::write_parquet(initial_init, global_files[["init_filename"]])
initial_init_file <- global_files[[paste0(config$initial_conditions$initial_file_type, "_filename")]]
}


if (!file.exists(initial_init_file)) {
stop("ERROR: Initial conditions file specified but does not exist.")
}

if (grepl(".csv", initial_init_file)){
initial_init <- readr::read_csv(initial_init_file,show_col_types = FALSE)
}else{
initial_init <- arrow::read_parquet(initial_init_file)
}

# if the initial conditions file contains a 'date' column, filter for config$start_date

if("date" %in% colnames(initial_init)){

initial_init <- initial_init %>%
dplyr::mutate(date = as.POSIXct(date, tz="UTC")) %>%
dplyr::filter(date == as.POSIXct(paste0(config$start_date, " 00:00:00"), tz="UTC"))

if (nrow(initial_init) == 0) {
stop("ERROR: Initial conditions file specified but does not contain the start date.")
}

}

arrow::write_parquet(initial_init, global_files[["init_filename"]])
}

# if the initial conditions file contains a 'date' column, filter for config$start_date
if (grepl(".csv", global_files[["init_filename"]])){
initial_init <- readr::read_csv(global_files[["init_filename"]],show_col_types = FALSE)
initial_init <- readr::read_csv(global_files[["init_filename"]],show_col_types = FALSE)
}else{
initial_init <- arrow::read_parquet(global_files[["init_filename"]])
initial_init <- arrow::read_parquet(global_files[["init_filename"]])
}

if("date" %in% colnames(initial_init)){
initial_init <- initial_init %>%
dplyr::mutate(date = as.POSIXct(date, tz="UTC")) %>%
dplyr::filter(date == as.POSIXct(paste0(config$start_date, " 00:00:00"), tz="UTC"))
if (nrow(initial_init) == 0) {
stop("ERROR: Initial conditions file specified but does not contain the start date.")
}

initial_init <- initial_init %>%
dplyr::mutate(date = as.POSIXct(date, tz="UTC")) %>%
dplyr::filter(date == as.POSIXct(paste0(config$start_date, " 00:00:00"), tz="UTC"))

if (nrow(initial_init) == 0) {
stop("ERROR: Initial conditions file specified but does not contain the start date.")
}

}
arrow::write_parquet(initial_init, global_files[["init_filename"]])
}else if(config$initial_conditions$method == "plugin"){
print("Initial conditions files generated by gempyor using plugin method.")
}
}


Expand Down
8 changes: 4 additions & 4 deletions flepimop/main_scripts/inference_slot.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ option_list = list(
optparse::make_option(c("-M", "--memory_profiling"), action = "store", default = Sys.getenv("FLEPI_MEM_PROFILE", FALSE), type = 'logical', help = 'Should the memory profiling be run during iterations'),
optparse::make_option(c("-P", "--memory_profiling_iters"), action = "store", default = Sys.getenv("FLEPI_MEM_PROF_ITERS", 100), type = 'integer', help = 'If doing memory profiling, after every X iterations run the profiler'),
optparse::make_option(c("-g", "--subpop_len"), action="store", default=Sys.getenv("SUBPOP_LENGTH", 5), type='integer', help = "number of digits in subpop"),
optparse::make_option(c("-a", "--incl_aggr_likelihood"), action = "store", default = Sys.getenv("INCL_AGGR_LIKELIHOOD", TRUE), type = 'logical', help = 'Should the likelihood be calculated with the aggregate estiamtes.')
optparse::make_option(c("-a", "--incl_aggr_likelihood"), action = "store", default = Sys.getenv("INCL_AGGR_LIKELIHOOD", FALSE), type = 'logical', help = 'Should the likelihood be calculated with the aggregate estiamtes.')
)

parser=optparse::OptionParser(option_list=option_list)
Expand Down Expand Up @@ -268,12 +268,12 @@ if (config$inference$do_inference){
obs <- suppressMessages(
readr::read_csv(config$inference$gt_data_path,
col_types = readr::cols(date = readr::col_date(),
source = readr::col_character(),
# source = readr::col_character(),
subpop = readr::col_character(),
.default = readr::col_double()), )) %>%
dplyr::filter(subpop %in% subpops_, date >= gt_start_date, date <= gt_end_date) %>%
dplyr::right_join(tidyr::expand_grid(subpop = unique(.$subpop), date = unique(.$date))) %>%
dplyr::mutate_if(is.numeric, dplyr::coalesce, 0)
dplyr::right_join(tidyr::expand_grid(subpop = unique(.$subpop), date = unique(.$date))) #%>%
# dplyr::mutate_if(is.numeric, dplyr::coalesce, 0)


# add aggregate groundtruth to the obs data for the likelihood calc
Expand Down

0 comments on commit fb5a9c5

Please sign in to comment.