From 13505840b3251746acbed4e79a69b656120e9319 Mon Sep 17 00:00:00 2001 From: Joseph Lemaitre Date: Mon, 30 Oct 2023 14:12:22 +0100 Subject: [PATCH] Add a new flepimop cli with click command group, currently support compartments plot/export. Solves #107 and lays the ground work for #106 --- flepimop/gempyor_pkg/setup.cfg | 1 + .../gempyor_pkg/src/gempyor/compartments.py | 62 ++++++++++++++++--- 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/flepimop/gempyor_pkg/setup.cfg b/flepimop/gempyor_pkg/setup.cfg index 3de937cb7..ce7fd7f1b 100644 --- a/flepimop/gempyor_pkg/setup.cfg +++ b/flepimop/gempyor_pkg/setup.cfg @@ -38,6 +38,7 @@ install_requires = [options.entry_points] console_scripts = gempyor-outcomes = gempyor.simulate_outcome:simulate + flepimop = gempyor.cli:cli gempyor-seir = gempyor.simulate_seir:simulate gempyor-simulate = gempyor.simulate:simulate diff --git a/flepimop/gempyor_pkg/src/gempyor/compartments.py b/flepimop/gempyor_pkg/src/gempyor/compartments.py index 4ccb32e89..b75a0ab45 100644 --- a/flepimop/gempyor_pkg/src/gempyor/compartments.py +++ b/flepimop/gempyor_pkg/src/gempyor/compartments.py @@ -2,7 +2,7 @@ import pandas as pd import pyarrow as pa import pyarrow.parquet as pq - +import click from .utils import config, Timer, as_list from . import file_paths from functools import reduce @@ -248,10 +248,13 @@ def parse_single_transition(self, seir_config, single_transition_config, fake_co return rc - def toFile(self, compartments_file, transitions_file): + def toFile(self, compartments_file, transitions_file, write_parquet=False): out_df = self.compartments.copy() - pa_df = pa.Table.from_pandas(out_df, preserve_index=False) - pa.parquet.write_table(pa_df, compartments_file) + if write_parquet: + pa_df = pa.Table.from_pandas(out_df, preserve_index=False) + pa.parquet.write_table(pa_df, compartments_file) + else: + out_df.to_csv(compartments_file, index=False) out_df = self.transitions.copy() out_df["source"] = self.format_source(out_df["source"]) @@ -259,9 +262,11 @@ def toFile(self, compartments_file, transitions_file): out_df["rate"] = self.format_rate(out_df["rate"]) out_df["proportional_to"] = self.format_proportional_to(out_df["proportional_to"]) out_df["proportion_exponent"] = self.format_proportion_exponent(out_df["proportion_exponent"]) - pa_df = pa.Table.from_pandas(out_df, preserve_index=False) - pa.parquet.write_table(pa_df, transitions_file) - + if write_parquet: + pa_df = pa.Table.from_pandas(out_df, preserve_index=False) + pa.parquet.write_table(pa_df, transitions_file) + else: + out_df.to_csv(transitions_file, index=False) return def fromFile(self, compartments_file, transitions_file): @@ -489,8 +494,8 @@ def parse_parameter_strings_to_numpy_arrays_v2(self, parameters, parameter_names # TODO: instead of searching for the next array, better to just use the parameter shape. if not isinstance(substituted_formulas[i], np.ndarray): for k in range(len(substituted_formulas)): - if isinstance(substituted_formulas[k], np.ndarray): - substituted_formulas[i] = substituted_formulas[i] * np.ones_like(substituted_formulas[k]) + if isinstance(substituted_formulas[k], np.ndarray): + substituted_formulas[i] = substituted_formulas[i] * np.ones_like(substituted_formulas[k]) return np.array(substituted_formulas) @@ -643,3 +648,42 @@ def list_recursive_convert_to_string(thing): if type(thing) == list: return [list_recursive_convert_to_string(x) for x in thing] return str(thing) + + + +@click.group() +def compartments(): + pass + +# TODO: CLI arguments +@compartments.command() +def plot(): + assert config["compartments"].exists() + assert config["seir"].exists() + comp = Compartments(seir_config=config["seir"], compartments_config=config["compartments"]) + + # TODO: this should be a command like build compartments. + ( + unique_strings, + transition_array, + proportion_array, + proportion_info, + ) = comp.get_transition_array() + + comp.plot(output_file="transition_graph", source_filters=[], destination_filters=[]) + + print("wrote file transition_graph") + +@compartments.command() +def export(): + assert config["compartments"].exists() + assert config["seir"].exists() + comp = Compartments(seir_config=config["seir"], compartments_config=config["compartments"]) + ( + unique_strings, + transition_array, + proportion_array, + proportion_info, + ) = comp.get_transition_array() + comp.toFile('compartments_file.csv', 'transitions_file.csv') + print("wrote files 'compartments_file.csv', 'transitions_file.csv' ") \ No newline at end of file