Skip to content

Commit

Permalink
Merge pull request #224 from HopkinsIDD/readability
Browse files Browse the repository at this point in the history
Improve readability and error messages for gempyor
  • Loading branch information
jcblemai authored Jul 12, 2024
2 parents ad08cce + 77c3cab commit 35b0e9e
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 60 deletions.
6 changes: 3 additions & 3 deletions flepimop/R_packages/flepicommon/R/compartments.R
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
parse_compartment_names <- function(seir_config) {
compartment_frame <- tidyr::expand_grid(!!!seir_config$compartments)
compartment_df <- tidyr::expand_grid(!!!seir_config$compartments)
compartment_names <- character(0)
for (component in compartment_frame) {
for (component in compartment_df) {
if (any(grepl("_", component, fixed = TRUE))) {
stop(paste("_", "is a reserved character, and cannot appear in compartment component names"))
}
compartment_names <- paste(compartment_names, component, sep = "_")
}
compartment_names <- stringr::str_sub(compartment_names, 2)
return(compartment_frame)
return(compartment_df)
}

assert <- function(bool, msg) {
Expand Down
187 changes: 140 additions & 47 deletions flepimop/gempyor_pkg/src/gempyor/compartments.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,55 @@ def __init__(self, seir_config=None, compartments_config=None, compartments_file
raise ValueError("Compartments object not set, no config or file provided")
return

def constructFromConfig(self, seir_config, compartment_config):
"""
This method is called by the constructor if the compartments are not loaded from a file.
It will parse the compartments and transitions from the configuration files.
It will populate self.compartments and self.transitions.
"""
self.compartments = self.parse_compartments(seir_config, compartment_config)
self.transitions = self.parse_transitions(seir_config, False)

def __eq__(self, other):
return (self.transitions == other.transitions).all().all() and (
self.compartments == other.compartments
).all().all()

def parse_compartments(self, seir_config, compartment_config):
compartment_frame = None
""" Parse the compartments from the configuration file:
seir_config: the configuration file for the SEIR model
compartment_config: the configuration file for the compartments
Example: if config says:
```
compartments:
infection_stage: ["S", "E", "I", "R"]
vaccination_stage: ["vaccinated", "unvaccinated"]
```
compartment_df is:
```
infection_stage vaccination_stage name
0 S vaccinated S_vaccinated
1 S unvaccinated S_unvaccinated
2 E vaccinated E_vaccinated
3 E unvaccinated E_unvaccinated
4 I vaccinated I_vaccinated
5 I unvaccinated I_unvaccinated
6 R vaccinated R_vaccinated
7 R unvaccinated R_unvaccinated
```
TODO: add tests
"""

compartment_df = None
for compartment_name, compartment_value in compartment_config.get().items():
tmp = pd.DataFrame({"key": 1, compartment_name: compartment_value})
if compartment_frame is None:
compartment_frame = tmp
if compartment_df is None:
compartment_df = tmp
else:
compartment_frame = pd.merge(compartment_frame, tmp, on="key")
compartment_frame = compartment_frame.drop(["key"], axis=1)
compartment_frame["name"] = compartment_frame.apply(lambda x: reduce(lambda a, b: a + "_" + b, x), axis=1)
self.compartments = compartment_frame
compartment_df = pd.merge(compartment_df, tmp, on="key")
compartment_df = compartment_df.drop(["key"], axis=1)
compartment_df["name"] = compartment_df.apply(lambda x: reduce(lambda a, b: a + "_" + b, x), axis=1)
return compartment_df

def parse_transitions(self, seir_config, fake_config=False):
rc = reduce(
Expand All @@ -65,12 +98,14 @@ def access_original_config_by_multi_index(self, config_piece, index, dimension=N
dimension = [None for i in index]
tmp = [y for y in zip(index, range(len(index)), dimension)]
tmp = zip(index, range(len(index)), dimension)
tmp = [list_access_element(config_piece[x[1]], x[0], x[2], encapsulate_as_list) for x in tmp]
tmp = [list_access_element_safe(config_piece[x[1]], x[0], x[2], encapsulate_as_list) for x in tmp]
return tmp

def expand_transition_elements(self, single_transition_config, problem_dimension):
proportion_size = get_list_dimension(single_transition_config["proportional_to"])
new_transition_config = single_transition_config.copy()

# replace "source" by the actual source from the config
for p_idx in range(proportion_size):
if new_transition_config["proportional_to"][p_idx] == "source":
new_transition_config["proportional_to"][p_idx] = new_transition_config["source"]
Expand All @@ -84,49 +119,92 @@ def expand_transition_elements(self, single_transition_config, problem_dimension
new_transition_config["proportional_to"] = np.zeros(problem_dimension, dtype=object)
new_transition_config["proportion_exponent"] = np.zeros(problem_dimension, dtype=object)

it = np.nditer(temp_array, flags=["multi_index"])
it = np.nditer(temp_array, flags=["multi_index"]) # it is an iterator that will go through all the indexes of the array
for x in it:
new_transition_config["source"][it.multi_index] = list_recursive_convert_to_string(
self.access_original_config_by_multi_index(single_transition_config["source"], it.multi_index)
)

new_transition_config["destination"][it.multi_index] = list_recursive_convert_to_string(
self.access_original_config_by_multi_index(single_transition_config["destination"], it.multi_index)
)

new_transition_config["rate"][it.multi_index] = list_recursive_convert_to_string(
self.access_original_config_by_multi_index(single_transition_config["rate"], it.multi_index)
)
try:
new_transition_config["source"][it.multi_index] = list_recursive_convert_to_string(
self.access_original_config_by_multi_index(single_transition_config["source"], it.multi_index)
)
except Exception as e:
print(f"Error {e}:")
print(f">>> in expand_transition_elements for `source:` at index {it.multi_index}")
print(f">>> this transition source is: {single_transition_config['source']}")
print(f">>> this transition destination is: {single_transition_config['destination']}")
print(f"transition_dimension: {problem_dimension}")
raise e

new_transition_config["proportional_to"][it.multi_index] = as_list(
list_recursive_convert_to_string(
[
self.access_original_config_by_multi_index(
single_transition_config["proportional_to"][p_idx],
it.multi_index,
problem_dimension,
True,
)
for p_idx in range(proportion_size)
]
try:
new_transition_config["destination"][it.multi_index] = list_recursive_convert_to_string(
self.access_original_config_by_multi_index(single_transition_config["destination"], it.multi_index)
)
)
except Exception as e:
print(f"Error {e}:")
print(f">>> in expand_transition_elements for `destination:` at index {it.multi_index}")
print(f">>> this transition source is: {single_transition_config['source']}")
print(f">>> this transition destination is: {single_transition_config['destination']}")
print(f"transition_dimension: {problem_dimension}")
raise e

try:
new_transition_config["rate"][it.multi_index] = list_recursive_convert_to_string(
self.access_original_config_by_multi_index(single_transition_config["rate"], it.multi_index)
)
except Exception as e:
print(f"Error {e}:")
print(f">>> in expand_transition_elements for `rate:` at index {it.multi_index}")
print(f">>> this transition source is: {single_transition_config['source']}")
print(f">>> this transition destination is: {single_transition_config['destination']}")
print(f"transition_dimension: {problem_dimension}")
raise e

self.access_original_config_by_multi_index(
single_transition_config["proportion_exponent"][0],
it.multi_index,
problem_dimension,
)
new_transition_config["proportion_exponent"][it.multi_index] = list_recursive_convert_to_string(
[
try:
new_transition_config["proportional_to"][it.multi_index] = as_list(
list_recursive_convert_to_string(
[
self.access_original_config_by_multi_index(
single_transition_config["proportional_to"][p_idx],
it.multi_index,
problem_dimension,
True,
)
for p_idx in range(proportion_size)
]
)
)
except Exception as e:
print(f"Error {e}:")
print(f">>> in expand_transition_elements for `proportional_to:` at index {it.multi_index}")
print(f">>> this transition source is: {single_transition_config['source']}")
print(f">>> this transition destination is: {single_transition_config['destination']}")
print(f"transition_dimension: {problem_dimension}")
raise e

if "proportion_exponent" in single_transition_config: # if proportion_exponent is not defined, it is set to 1
try:
self.access_original_config_by_multi_index(
single_transition_config["proportion_exponent"][p_idx],
single_transition_config["proportion_exponent"][0],
it.multi_index,
problem_dimension,
)
for p_idx in range(proportion_size)
]
)
new_transition_config["proportion_exponent"][it.multi_index] = list_recursive_convert_to_string(
[
self.access_original_config_by_multi_index(
single_transition_config["proportion_exponent"][p_idx],
it.multi_index,
problem_dimension,
)
for p_idx in range(proportion_size)
]
)
except Exception as e:
print(f"Error {e}:")
print(f">>> in expand_transition_elements for `proportion_exponent:` at index {it.multi_index}")
print(f">>> this transition source is: {single_transition_config['source']}")
print(f">>> this transition destination is: {single_transition_config['destination']}")
print(f"transition_dimension: {problem_dimension}")
raise e
else:
new_transition_config["proportion_exponent"][it.multi_index] = ["1"] * proportion_size

return new_transition_config

Expand Down Expand Up @@ -216,6 +294,7 @@ def unformat_proportion_exponent(self, proportion_exponent_column, compartment_d
return rc

def parse_single_transition(self, seir_config, single_transition_config, fake_config=False):

## This method relies on having run parse_compartments
if not fake_config:
single_transition_config = single_transition_config.get()
Expand Down Expand Up @@ -303,9 +382,6 @@ def get_comp_idx(self, comp_dict: dict, error_info: str = "no information") -> i
def get_ncomp(self) -> int:
return len(self.compartments)

def constructFromConfig(self, seir_config, compartment_config):
self.parse_compartments(seir_config, compartment_config)
self.transitions = self.parse_transitions(seir_config, False)

def get_transition_array(self):
with Timer("SEIR.compartments"):
Expand Down Expand Up @@ -630,7 +706,24 @@ def get_list_dimension(thing):
return 1


def list_access_element_safe(thing, idx, dimension=None, encapsulate_as_list=False):
try:
return list_access_element(thing, idx, dimension, encapsulate_as_list)
except Exception as e:
print(f"Error {e}:")
print(f">>> in list_access_element_safe for {thing} at index {idx}")
print(">>> This is often, but not always because the object above is a list (there are brackets around it).")
print(">>> and in this case it is not broadcast, so if you want to it to be broadcasted, you need remove the brackets around it.")
print(f"dimension: {dimension}")
raise e


def list_access_element(thing, idx, dimension=None, encapsulate_as_list=False):
"""
This function is used to access elements in a list or a single element.
if list, it will return the element at index idx.
if not list, it will return the element itself, for any idx.
"""
if not dimension is None:
if dimension == 1:
rc = as_list(thing)
Expand Down Expand Up @@ -688,5 +781,5 @@ def export():
proportion_array,
proportion_info,
) = comp.get_transition_array()
comp.toFile("compartments_file.csv", "transitions_file.csv")
comp.toFile("compartments_file.csv", "transitions_file.csv", write_parquet=False)
print("wrote files 'compartments_file.csv', 'transitions_file.csv' ")
3 changes: 2 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ def __init__(
# Config prep
config.clear()
config.read(user=False)
config.set_file(path_prefix + config_filepath)

config.set_file(os.path.join(path_prefix, config_filepath))

self.seir_modifiers_scenario, self.outcome_modifiers_scenario = autodetect_scenarios(config)

Expand Down
19 changes: 10 additions & 9 deletions flepimop/gempyor_pkg/src/gempyor/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ class ModelInfo:
"""
Parse config and hold some results, with main config sections.
```
# subpop_setup # Always required
# compartments # Required if running seir
# seir # Required if running seir
# initial_conditions # One of seeding or initial_conditions is required when running seir
# seeding # One of seeding or initial_conditions is required when running seir
# outcomes # Required if running outcomes
# seir_modifiers # Not required. If exists, every modifier will be applied to seir parameters
# outcome_modifiers # Not required. If exists, every modifier will be applied to outcomes parameters
# inference # Required if running inference
subpop_setup # Always required
compartments # Required if running seir
parameters # required if running seir
seir # Required if running seir
initial_conditions # One of seeding or initial_conditions is required when running seir
seeding # One of seeding or initial_conditions is required when running seir
outcomes # Required if running outcomes
seir_modifiers # Not required. If exists, every modifier will be applied to seir parameters
outcomes_modifiers # Not required. If exists, every modifier will be applied to outcomes
inference # Required if running inference
```
"""

Expand Down
1 change: 1 addition & 0 deletions flepimop/gempyor_pkg/tests/seir/test_compartments.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,4 @@ def test_ModelInfo_has_compartments_component():
)
assert type(s.compartments) == compartments.Compartments
assert type(s.compartments) == compartments.Compartments

0 comments on commit 35b0e9e

Please sign in to comment.