Skip to content

Commit

Permalink
Add a new flepimop cli with click command group, currently support co…
Browse files Browse the repository at this point in the history
…mpartments plot/export. Solves #107 and lays the ground work for #106
  • Loading branch information
jcblemai committed Oct 30, 2023
1 parent 13542d1 commit 1350584
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 9 deletions.
1 change: 1 addition & 0 deletions flepimop/gempyor_pkg/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ install_requires =
[options.entry_points]
console_scripts =
gempyor-outcomes = gempyor.simulate_outcome:simulate
flepimop = gempyor.cli:cli
gempyor-seir = gempyor.simulate_seir:simulate
gempyor-simulate = gempyor.simulate:simulate

Expand Down
62 changes: 53 additions & 9 deletions flepimop/gempyor_pkg/src/gempyor/compartments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

import click
from .utils import config, Timer, as_list
from . import file_paths
from functools import reduce
Expand Down Expand Up @@ -248,20 +248,25 @@ def parse_single_transition(self, seir_config, single_transition_config, fake_co

return rc

def toFile(self, compartments_file, transitions_file):
def toFile(self, compartments_file, transitions_file, write_parquet=False):
out_df = self.compartments.copy()
pa_df = pa.Table.from_pandas(out_df, preserve_index=False)
pa.parquet.write_table(pa_df, compartments_file)
if write_parquet:
pa_df = pa.Table.from_pandas(out_df, preserve_index=False)
pa.parquet.write_table(pa_df, compartments_file)
else:
out_df.to_csv(compartments_file, index=False)

out_df = self.transitions.copy()
out_df["source"] = self.format_source(out_df["source"])
out_df["destination"] = self.format_destination(out_df["destination"])
out_df["rate"] = self.format_rate(out_df["rate"])
out_df["proportional_to"] = self.format_proportional_to(out_df["proportional_to"])
out_df["proportion_exponent"] = self.format_proportion_exponent(out_df["proportion_exponent"])
pa_df = pa.Table.from_pandas(out_df, preserve_index=False)
pa.parquet.write_table(pa_df, transitions_file)

if write_parquet:
pa_df = pa.Table.from_pandas(out_df, preserve_index=False)
pa.parquet.write_table(pa_df, transitions_file)
else:
out_df.to_csv(transitions_file, index=False)
return

def fromFile(self, compartments_file, transitions_file):
Expand Down Expand Up @@ -489,8 +494,8 @@ def parse_parameter_strings_to_numpy_arrays_v2(self, parameters, parameter_names
# TODO: instead of searching for the next array, better to just use the parameter shape.
if not isinstance(substituted_formulas[i], np.ndarray):
for k in range(len(substituted_formulas)):
if isinstance(substituted_formulas[k], np.ndarray):
substituted_formulas[i] = substituted_formulas[i] * np.ones_like(substituted_formulas[k])
if isinstance(substituted_formulas[k], np.ndarray):
substituted_formulas[i] = substituted_formulas[i] * np.ones_like(substituted_formulas[k])

return np.array(substituted_formulas)

Expand Down Expand Up @@ -643,3 +648,42 @@ def list_recursive_convert_to_string(thing):
if type(thing) == list:
return [list_recursive_convert_to_string(x) for x in thing]
return str(thing)



@click.group()
def compartments():
pass

# TODO: CLI arguments
@compartments.command()
def plot():
assert config["compartments"].exists()
assert config["seir"].exists()
comp = Compartments(seir_config=config["seir"], compartments_config=config["compartments"])

# TODO: this should be a command like build compartments.
(
unique_strings,
transition_array,
proportion_array,
proportion_info,
) = comp.get_transition_array()

comp.plot(output_file="transition_graph", source_filters=[], destination_filters=[])

print("wrote file transition_graph")

@compartments.command()
def export():
assert config["compartments"].exists()
assert config["seir"].exists()
comp = Compartments(seir_config=config["seir"], compartments_config=config["compartments"])
(
unique_strings,
transition_array,
proportion_array,
proportion_info,
) = comp.get_transition_array()
comp.toFile('compartments_file.csv', 'transitions_file.csv')
print("wrote files 'compartments_file.csv', 'transitions_file.csv' ")

0 comments on commit 1350584

Please sign in to comment.