diff --git a/flepimop/main_scripts/inference_slot.R b/flepimop/main_scripts/inference_slot.R index 9cc61ad76..ef0edd1e7 100644 --- a/flepimop/main_scripts/inference_slot.R +++ b/flepimop/main_scripts/inference_slot.R @@ -50,7 +50,8 @@ option_list = list( optparse::make_option(c("-L", "--reset_chimeric_on_accept"), action = "store", default = Sys.getenv("FLEPI_RESET_CHIMERICS", TRUE), type = 'logical', help = 'Should the chimeric parameters get reset to global parameters when a global acceptance occurs'), optparse::make_option(c("-M", "--memory_profiling"), action = "store", default = Sys.getenv("FLEPI_MEM_PROFILE", FALSE), type = 'logical', help = 'Should the memory profiling be run during iterations'), optparse::make_option(c("-P", "--memory_profiling_iters"), action = "store", default = Sys.getenv("FLEPI_MEM_PROF_ITERS", 100), type = 'integer', help = 'If doing memory profiling, after every X iterations run the profiler'), - optparse::make_option(c("-g", "--subpop_len"), action="store", default=Sys.getenv("SUBPOP_LENGTH", 5), type='integer', help = "number of digits in subpop") + optparse::make_option(c("-g", "--subpop_len"), action="store", default=Sys.getenv("SUBPOP_LENGTH", 5), type='integer', help = "number of digits in subpop"), + optparse::make_option(c("-a", "--incl_aggr_likelihood"), action = "store", default = Sys.getenv("INCL_AGGR_LIKELIHOOD", TRUE), type = 'logical', help = 'Should the likelihood be calculated with the aggregate estiamtes.') ) parser=optparse::OptionParser(option_list=option_list) @@ -88,6 +89,18 @@ if (opt$config == ""){ } config = flepicommon::load_config(opt$config) + +if (!is.null(config$inference$incl_aggr_likelihood)){ + print("Using config option for `incl_aggr_likelihood`.") + opt$incl_aggr_likelihood <- config$inference$incl_aggr_likelihood + if (!is.null(config$inference$total_ll_multiplier)){ + print("Using config option for `total_ll_multiplier`.") + opt$total_ll_multiplier <- config$inference$total_ll_multiplier + } else { + opt$total_ll_multiplier <- 1 + } +} + ## Check for errors in config --------------------------------------------------------------------- ## seeding section @@ -265,6 +278,19 @@ if (config$inference$do_inference){ dplyr::right_join(tidyr::expand_grid(subpop = unique(.$subpop), date = unique(.$date))) %>% dplyr::mutate_if(is.numeric, dplyr::coalesce, 0) + # add aggregate groundtruth to the obs data for the likelihood calc + if (opt$incl_aggr_likelihood){ + obs <- obs %>% + dplyr::bind_rows( + obs %>% + dplyr::select(date, where(is.numeric)) %>% + dplyr::group_by(date) %>% + dplyr::summarise(across(everything(), sum)) %>% # no likelihood is calculated for time periods with missing data for any subpop + dplyr::mutate(source = "Total", + subpop = "Total") + ) + } + subpopnames <- unique(obs[[obs_subpop]]) @@ -588,7 +614,19 @@ for(seir_modifiers_scenario in seir_modifiers_scenarios) { # run if (config$inference$do_inference){ sim_hosp <- flepicommon::read_file_of_type(gsub(".*[.]","",this_global_files[['hosp_filename']]))(this_global_files[['hosp_filename']]) %>% - dplyr::filter(time >= min(obs$date),time <= max(obs$date)) + dplyr::filter(time >= min(obs$date), time <= max(obs$date)) + + # add aggregate groundtruth to the obs data for the likelihood calc + if (opt$incl_aggr_likelihood){ + sim_hosp <- sim_hosp %>% + dplyr::bind_rows( + sim_hosp %>% + dplyr::select(-tidyselect::all_of(obs_subpop), -tidyselect::starts_with("date")) %>% + dplyr::group_by(time) %>% + dplyr::summarise(dplyr::across(tidyselect::everything(), sum)) %>% # no likelihood is calculated for time periods with missing data for any subpop + dplyr::mutate(!!obs_subpop := "Total") + ) + } lhs <- unique(sim_hosp[[obs_subpop]]) rhs <- unique(names(data_stats)) @@ -624,6 +662,12 @@ for(seir_modifiers_scenario in seir_modifiers_scenarios) { rm(sim_hosp) + # multiply aggregate likelihood by a factor if specified in config + if (opt$incl_aggr_likelihood){ + proposed_likelihood_data$ll[proposed_likelihood_data$subpop == "Total"] <- proposed_likelihood_data$ll[proposed_likelihood_data$subpop == "Total"] * opt$total_ll_multiplier + } + + # write proposed likelihood to global file arrow::write_parquet(proposed_likelihood_data, this_global_files[['llik_filename']])