diff --git a/flepimop/R_packages/flepicommon/R/compartments.R b/flepimop/R_packages/flepicommon/R/compartments.R index 31e0b5a32..51ff0c0b4 100644 --- a/flepimop/R_packages/flepicommon/R/compartments.R +++ b/flepimop/R_packages/flepicommon/R/compartments.R @@ -1,14 +1,14 @@ parse_compartment_names <- function(seir_config) { - compartment_frame <- tidyr::expand_grid(!!!seir_config$compartments) + compartment_df <- tidyr::expand_grid(!!!seir_config$compartments) compartment_names <- character(0) - for (component in compartment_frame) { + for (component in compartment_df) { if (any(grepl("_", component, fixed = TRUE))) { stop(paste("_", "is a reserved character, and cannot appear in compartment component names")) } compartment_names <- paste(compartment_names, component, sep = "_") } compartment_names <- stringr::str_sub(compartment_names, 2) - return(compartment_frame) + return(compartment_df) } assert <- function(bool, msg) { diff --git a/flepimop/gempyor_pkg/src/gempyor/compartments.py b/flepimop/gempyor_pkg/src/gempyor/compartments.py index b5953a4f5..ec87cf7e5 100644 --- a/flepimop/gempyor_pkg/src/gempyor/compartments.py +++ b/flepimop/gempyor_pkg/src/gempyor/compartments.py @@ -28,22 +28,55 @@ def __init__(self, seir_config=None, compartments_config=None, compartments_file raise ValueError("Compartments object not set, no config or file provided") return + def constructFromConfig(self, seir_config, compartment_config): + """ + This method is called by the constructor if the compartments are not loaded from a file. + It will parse the compartments and transitions from the configuration files. + It will populate self.compartments and self.transitions. + """ + self.compartments = self.parse_compartments(seir_config, compartment_config) + self.transitions = self.parse_transitions(seir_config, False) + def __eq__(self, other): return (self.transitions == other.transitions).all().all() and ( self.compartments == other.compartments ).all().all() def parse_compartments(self, seir_config, compartment_config): - compartment_frame = None + """ Parse the compartments from the configuration file: + seir_config: the configuration file for the SEIR model + compartment_config: the configuration file for the compartments + Example: if config says: + ``` + compartments: + infection_stage: ["S", "E", "I", "R"] + vaccination_stage: ["vaccinated", "unvaccinated"] + ``` + compartment_df is: + ``` + infection_stage vaccination_stage name + 0 S vaccinated S_vaccinated + 1 S unvaccinated S_unvaccinated + 2 E vaccinated E_vaccinated + 3 E unvaccinated E_unvaccinated + 4 I vaccinated I_vaccinated + 5 I unvaccinated I_unvaccinated + 6 R vaccinated R_vaccinated + 7 R unvaccinated R_unvaccinated + ``` + TODO: add tests + """ + + compartment_df = None for compartment_name, compartment_value in compartment_config.get().items(): tmp = pd.DataFrame({"key": 1, compartment_name: compartment_value}) - if compartment_frame is None: - compartment_frame = tmp + if compartment_df is None: + compartment_df = tmp else: - compartment_frame = pd.merge(compartment_frame, tmp, on="key") - compartment_frame = compartment_frame.drop(["key"], axis=1) - compartment_frame["name"] = compartment_frame.apply(lambda x: reduce(lambda a, b: a + "_" + b, x), axis=1) - self.compartments = compartment_frame + compartment_df = pd.merge(compartment_df, tmp, on="key") + compartment_df = compartment_df.drop(["key"], axis=1) + compartment_df["name"] = compartment_df.apply(lambda x: reduce(lambda a, b: a + "_" + b, x), axis=1) + return compartment_df def parse_transitions(self, seir_config, fake_config=False): rc = reduce( @@ -65,12 +98,14 @@ def access_original_config_by_multi_index(self, config_piece, index, dimension=N dimension = [None for i in index] tmp = [y for y in zip(index, range(len(index)), dimension)] tmp = zip(index, range(len(index)), dimension) - tmp = [list_access_element(config_piece[x[1]], x[0], x[2], encapsulate_as_list) for x in tmp] + tmp = [list_access_element_safe(config_piece[x[1]], x[0], x[2], encapsulate_as_list) for x in tmp] return tmp def expand_transition_elements(self, single_transition_config, problem_dimension): proportion_size = get_list_dimension(single_transition_config["proportional_to"]) new_transition_config = single_transition_config.copy() + + # replace "source" by the actual source from the config for p_idx in range(proportion_size): if new_transition_config["proportional_to"][p_idx] == "source": new_transition_config["proportional_to"][p_idx] = new_transition_config["source"] @@ -84,49 +119,92 @@ def expand_transition_elements(self, single_transition_config, problem_dimension new_transition_config["proportional_to"] = np.zeros(problem_dimension, dtype=object) new_transition_config["proportion_exponent"] = np.zeros(problem_dimension, dtype=object) - it = np.nditer(temp_array, flags=["multi_index"]) + it = np.nditer(temp_array, flags=["multi_index"]) # it is an iterator that will go through all the indexes of the array for x in it: - new_transition_config["source"][it.multi_index] = list_recursive_convert_to_string( - self.access_original_config_by_multi_index(single_transition_config["source"], it.multi_index) - ) - - new_transition_config["destination"][it.multi_index] = list_recursive_convert_to_string( - self.access_original_config_by_multi_index(single_transition_config["destination"], it.multi_index) - ) - - new_transition_config["rate"][it.multi_index] = list_recursive_convert_to_string( - self.access_original_config_by_multi_index(single_transition_config["rate"], it.multi_index) - ) + try: + new_transition_config["source"][it.multi_index] = list_recursive_convert_to_string( + self.access_original_config_by_multi_index(single_transition_config["source"], it.multi_index) + ) + except Exception as e: + print(f"Error {e}:") + print(f">>> in expand_transition_elements for `source:` at index {it.multi_index}") + print(f">>> this transition source is: {single_transition_config['source']}") + print(f">>> this transition destination is: {single_transition_config['destination']}") + print(f"transition_dimension: {problem_dimension}") + raise e - new_transition_config["proportional_to"][it.multi_index] = as_list( - list_recursive_convert_to_string( - [ - self.access_original_config_by_multi_index( - single_transition_config["proportional_to"][p_idx], - it.multi_index, - problem_dimension, - True, - ) - for p_idx in range(proportion_size) - ] + try: + new_transition_config["destination"][it.multi_index] = list_recursive_convert_to_string( + self.access_original_config_by_multi_index(single_transition_config["destination"], it.multi_index) ) - ) + except Exception as e: + print(f"Error {e}:") + print(f">>> in expand_transition_elements for `destination:` at index {it.multi_index}") + print(f">>> this transition source is: {single_transition_config['source']}") + print(f">>> this transition destination is: {single_transition_config['destination']}") + print(f"transition_dimension: {problem_dimension}") + raise e + + try: + new_transition_config["rate"][it.multi_index] = list_recursive_convert_to_string( + self.access_original_config_by_multi_index(single_transition_config["rate"], it.multi_index) + ) + except Exception as e: + print(f"Error {e}:") + print(f">>> in expand_transition_elements for `rate:` at index {it.multi_index}") + print(f">>> this transition source is: {single_transition_config['source']}") + print(f">>> this transition destination is: {single_transition_config['destination']}") + print(f"transition_dimension: {problem_dimension}") + raise e - self.access_original_config_by_multi_index( - single_transition_config["proportion_exponent"][0], - it.multi_index, - problem_dimension, - ) - new_transition_config["proportion_exponent"][it.multi_index] = list_recursive_convert_to_string( - [ + try: + new_transition_config["proportional_to"][it.multi_index] = as_list( + list_recursive_convert_to_string( + [ + self.access_original_config_by_multi_index( + single_transition_config["proportional_to"][p_idx], + it.multi_index, + problem_dimension, + True, + ) + for p_idx in range(proportion_size) + ] + ) + ) + except Exception as e: + print(f"Error {e}:") + print(f">>> in expand_transition_elements for `proportional_to:` at index {it.multi_index}") + print(f">>> this transition source is: {single_transition_config['source']}") + print(f">>> this transition destination is: {single_transition_config['destination']}") + print(f"transition_dimension: {problem_dimension}") + raise e + + if "proportion_exponent" in single_transition_config: # if proportion_exponent is not defined, it is set to 1 + try: self.access_original_config_by_multi_index( - single_transition_config["proportion_exponent"][p_idx], + single_transition_config["proportion_exponent"][0], it.multi_index, problem_dimension, ) - for p_idx in range(proportion_size) - ] - ) + new_transition_config["proportion_exponent"][it.multi_index] = list_recursive_convert_to_string( + [ + self.access_original_config_by_multi_index( + single_transition_config["proportion_exponent"][p_idx], + it.multi_index, + problem_dimension, + ) + for p_idx in range(proportion_size) + ] + ) + except Exception as e: + print(f"Error {e}:") + print(f">>> in expand_transition_elements for `proportion_exponent:` at index {it.multi_index}") + print(f">>> this transition source is: {single_transition_config['source']}") + print(f">>> this transition destination is: {single_transition_config['destination']}") + print(f"transition_dimension: {problem_dimension}") + raise e + else: + new_transition_config["proportion_exponent"][it.multi_index] = ["1"] * proportion_size return new_transition_config @@ -216,6 +294,7 @@ def unformat_proportion_exponent(self, proportion_exponent_column, compartment_d return rc def parse_single_transition(self, seir_config, single_transition_config, fake_config=False): + ## This method relies on having run parse_compartments if not fake_config: single_transition_config = single_transition_config.get() @@ -303,9 +382,6 @@ def get_comp_idx(self, comp_dict: dict, error_info: str = "no information") -> i def get_ncomp(self) -> int: return len(self.compartments) - def constructFromConfig(self, seir_config, compartment_config): - self.parse_compartments(seir_config, compartment_config) - self.transitions = self.parse_transitions(seir_config, False) def get_transition_array(self): with Timer("SEIR.compartments"): @@ -630,7 +706,24 @@ def get_list_dimension(thing): return 1 +def list_access_element_safe(thing, idx, dimension=None, encapsulate_as_list=False): + try: + return list_access_element(thing, idx, dimension, encapsulate_as_list) + except Exception as e: + print(f"Error {e}:") + print(f">>> in list_access_element_safe for {thing} at index {idx}") + print(">>> This is often, but not always because the object above is a list (there are brackets around it).") + print(">>> and in this case it is not broadcast, so if you want to it to be broadcasted, you need remove the brackets around it.") + print(f"dimension: {dimension}") + raise e + + def list_access_element(thing, idx, dimension=None, encapsulate_as_list=False): + """ + This function is used to access elements in a list or a single element. + if list, it will return the element at index idx. + if not list, it will return the element itself, for any idx. + """ if not dimension is None: if dimension == 1: rc = as_list(thing) @@ -688,5 +781,5 @@ def export(): proportion_array, proportion_info, ) = comp.get_transition_array() - comp.toFile("compartments_file.csv", "transitions_file.csv") + comp.toFile("compartments_file.csv", "transitions_file.csv", write_parquet=False) print("wrote files 'compartments_file.csv', 'transitions_file.csv' ") diff --git a/flepimop/gempyor_pkg/src/gempyor/inference.py b/flepimop/gempyor_pkg/src/gempyor/inference.py index d8128361d..24f205c6d 100644 --- a/flepimop/gempyor_pkg/src/gempyor/inference.py +++ b/flepimop/gempyor_pkg/src/gempyor/inference.py @@ -330,7 +330,8 @@ def __init__( # Config prep config.clear() config.read(user=False) - config.set_file(path_prefix + config_filepath) + + config.set_file(os.path.join(path_prefix, config_filepath)) self.seir_modifiers_scenario, self.outcome_modifiers_scenario = autodetect_scenarios(config) diff --git a/flepimop/gempyor_pkg/src/gempyor/model_info.py b/flepimop/gempyor_pkg/src/gempyor/model_info.py index 043c98876..54f981d8f 100644 --- a/flepimop/gempyor_pkg/src/gempyor/model_info.py +++ b/flepimop/gempyor_pkg/src/gempyor/model_info.py @@ -21,15 +21,16 @@ class ModelInfo: """ Parse config and hold some results, with main config sections. ``` - # subpop_setup # Always required - # compartments # Required if running seir - # seir # Required if running seir - # initial_conditions # One of seeding or initial_conditions is required when running seir - # seeding # One of seeding or initial_conditions is required when running seir - # outcomes # Required if running outcomes - # seir_modifiers # Not required. If exists, every modifier will be applied to seir parameters - # outcome_modifiers # Not required. If exists, every modifier will be applied to outcomes parameters - # inference # Required if running inference + subpop_setup # Always required + compartments # Required if running seir + parameters # required if running seir + seir # Required if running seir + initial_conditions # One of seeding or initial_conditions is required when running seir + seeding # One of seeding or initial_conditions is required when running seir + outcomes # Required if running outcomes + seir_modifiers # Not required. If exists, every modifier will be applied to seir parameters + outcomes_modifiers # Not required. If exists, every modifier will be applied to outcomes + inference # Required if running inference ``` """ diff --git a/flepimop/gempyor_pkg/tests/seir/test_compartments.py b/flepimop/gempyor_pkg/tests/seir/test_compartments.py index 35d8f6893..1d4319e3b 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_compartments.py +++ b/flepimop/gempyor_pkg/tests/seir/test_compartments.py @@ -86,3 +86,4 @@ def test_ModelInfo_has_compartments_component(): ) assert type(s.compartments) == compartments.Compartments assert type(s.compartments) == compartments.Compartments +