From bb207a203c5619c7f002829da3ba2cf4325a6bb0 Mon Sep 17 00:00:00 2001
From: saraloo <45245630+saraloo@users.noreply.github.com>
Date: Mon, 24 Jun 2024 15:05:07 -0400
Subject: [PATCH 01/20] fix errors in aggregation stuff

---
 .../{config_library => }/config_sample_2pop_inference.yml | 8 ++++----
 flepimop/main_scripts/inference_slot.R                    | 7 ++++---
 2 files changed, 8 insertions(+), 7 deletions(-)
 rename examples/tutorial_two_subpops/{config_library => }/config_sample_2pop_inference.yml (98%)

diff --git a/examples/tutorial_two_subpops/config_library/config_sample_2pop_inference.yml b/examples/tutorial_two_subpops/config_sample_2pop_inference.yml
similarity index 98%
rename from examples/tutorial_two_subpops/config_library/config_sample_2pop_inference.yml
rename to examples/tutorial_two_subpops/config_sample_2pop_inference.yml
index e6f644201..bd3cba6d5 100644
--- a/examples/tutorial_two_subpops/config_library/config_sample_2pop_inference.yml
+++ b/examples/tutorial_two_subpops/config_sample_2pop_inference.yml
@@ -47,7 +47,7 @@ seir:
 seir_modifiers:
   scenarios:
-    - Ro_all
+    - inference
   modifiers:
     Ro_mod: # assume same for all subpopulations
       method: SinglePeriodModifier
@@ -86,7 +86,7 @@ seir_modifiers:
         sd: 0.025
         a: -0.1
         b: 0.1
-    Ro_all:
+    inference:
       method: StackedModifier
       modifiers: ["Ro_mod","Ro_lockdown"]
@@ -122,10 +122,10 @@ outcomes:
 outcome_modifiers:
   scenarios:
-    - test_limits
+    - all
   modifiers:
     # assume that due to limitations in testing, initially the case detection probability was lower
-    test_limits:
+    all:
       method: SinglePeriodModifier
       parameter: incidCase
       subpop: "all"
diff --git a/flepimop/main_scripts/inference_slot.R b/flepimop/main_scripts/inference_slot.R
index 90d289640..031838c03 100644
--- a/flepimop/main_scripts/inference_slot.R
+++ b/flepimop/main_scripts/inference_slot.R
@@ -91,6 +91,7 @@ if (opt$config == ""){
 }
 config = flepicommon::load_config(opt$config)

+opt$total_ll_multiplier <- 1
 if (!is.null(config$inference$incl_aggr_likelihood)){
   print("Using config option for `incl_aggr_likelihood`.")
   opt$incl_aggr_likelihood <- config$inference$incl_aggr_likelihood
@@ -450,7 +451,7 @@ for(seir_modifiers_scenario in seir_modifiers_scenarios) {
           autowrite_seir = TRUE
         )
       }, error = function(e) {
-        print("GempyorInference failed to run (call on l. 405 of inference_slot.R).")
+        print("GempyorInference failed to run (call on l. 443 of inference_slot.R).")
         print("Here is all the debug information I could find:")
         for(m in reticulate::py_last_error()) print(m)
         stop("GempyorInference failed to run... stopping")
@@ -637,8 +638,8 @@ for(seir_modifiers_scenario in seir_modifiers_scenarios) {
     sim_hosp <- sim_hosp %>%
       dplyr::bind_rows(
         sim_hosp %>%
-          dplyr::select(-tidyselect::all_of(obs_subpop), -tidyselect::starts_with("date")) %>%
-          dplyr::group_by(time) %>%
+          dplyr::select(-tidyselect::all_of(obs_subpop)) %>%
+          dplyr::group_by(date) %>%
           dplyr::summarise(dplyr::across(tidyselect::everything(), sum)) %>% # no likelihood is calculated for time periods with missing data for any subpop
           dplyr::mutate(!!obs_subpop := "Total")
       )

From 9b99b89b823ada1c7c9a8a06d268e99cd91e15d3 Mon Sep 17 00:00:00 2001
From: Timothy Willard
Date: Mon, 1 Jul 2024 11:46:20 -0400
Subject: [PATCH 02/20] Overhaul write_df unit tests.

Overhauled the gempyor.utils.write_df unit tests by placing them in a new
file with a class grouping similar fixtures. Added tests for the
NotImplementedError, writing to csv, and writing to parquet.
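For reference, the round-trip check these tests are built around can be sketched
as follows; this is a minimal illustration using the same public API, not code
from the patch itself:

    from pathlib import Path
    from tempfile import NamedTemporaryFile

    import pandas as pd

    from gempyor.utils import write_df

    df = pd.DataFrame({"abc": [1, 2, 3], "def": ["x", "y", "z"]})
    with NamedTemporaryFile(suffix=".csv") as temp_file:
        temp_path = Path(temp_file.name)
        # Format is inferred from the .csv suffix; the index is not written.
        write_df(fname=temp_path, df=df)
        assert pd.read_csv(temp_path).equals(df)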
--- .../gempyor_pkg/tests/utils/test_write_df.py | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 flepimop/gempyor_pkg/tests/utils/test_write_df.py diff --git a/flepimop/gempyor_pkg/tests/utils/test_write_df.py b/flepimop/gempyor_pkg/tests/utils/test_write_df.py new file mode 100644 index 000000000..d465f6434 --- /dev/null +++ b/flepimop/gempyor_pkg/tests/utils/test_write_df.py @@ -0,0 +1,127 @@ +import os +from tempfile import NamedTemporaryFile +from pathlib import Path +from typing import Callable, Any + +import pytest +import pandas as pd + +from gempyor.utils import write_df + + +class TestWriteDf: + """ + Unit tests for the `gempyor.utils.write_df` function. + """ + + sample_df: pd.DataFrame = pd.DataFrame( + { + "abc": [1, 2, 3, 4, 5], + "def": ["v", "w", "x", "y", "z"], + "ghi": [True, False, False, None, True], + "jkl": [1.2, 3.4, 5.6, 7.8, 9.0], + } + ) + + def test_raises_not_implemented_error(self) -> None: + """ + Tests that write_df raises a NotImplementedError for unsupported file + extensions. + """ + with pytest.raises( + expected_exception=NotImplementedError, + match="Invalid extension txt. Must be 'csv' or 'parquet'", + ) as _: + with NamedTemporaryFile(suffix=".txt") as temp_file: + write_df(fname=temp_file.name, df=self.sample_df) + + @pytest.mark.parametrize( + "fname_transformer,extension", + [ + (lambda x: str(x), ""), + (lambda x: x, ""), + (lambda x: f"{x.parent}/{x.stem}", "csv"), + (lambda x: Path(f"{x.parent}/{x.stem}"), "csv"), + ], + ) + def test_write_csv_dataframe( + self, + fname_transformer: Callable[[os.PathLike], Any], + extension: str, + ) -> None: + """ + Tests writing a DataFrame to a CSV file. + + Args: + fname_transformer: A function that transforms the file name to create the + `fname` arg. + extension: The file extension to use, provided directly to + `gempyor.utils.write_df`. + """ + self._test_write_df( + fname_transformer=fname_transformer, + df=self.sample_df, + extension=extension, + suffix=".csv", + path_reader=lambda x: pd.read_csv(x, index_col=False), + ) + + @pytest.mark.parametrize( + "fname_transformer,extension", + [ + (lambda x: str(x), ""), + (lambda x: x, ""), + (lambda x: f"{x.parent}/{x.stem}", "parquet"), + (lambda x: Path(f"{x.parent}/{x.stem}"), "parquet"), + ], + ) + def test_write_parquet_dataframe( + self, + fname_transformer: Callable[[os.PathLike], Any], + extension: str, + ) -> None: + """ + Tests writing a DataFrame to a Parquet file. + + Args: + fname_transformer: A function that transforms the file name to create the + `fname` arg. + extension: The file extension to use, provided directly to + `gempyor.utils.write_df`. + """ + self._test_write_df( + fname_transformer=fname_transformer, + df=self.sample_df, + extension=extension, + suffix=".parquet", + path_reader=lambda x: pd.read_parquet(x), + ) + + def _test_write_df( + self, + fname_transformer: Callable[[os.PathLike], Any], + df: pd.DataFrame, + extension: str, + suffix: str | None, + path_reader: Callable[[os.PathLike], pd.DataFrame], + ) -> None: + """ + Helper method to test writing a DataFrame to a file. + + Args: + fname_transformer: A function that transforms the file name. + df: The DataFrame to write. + extension: The file extension to use. + suffix: The suffix to use for the temporary file. + path_reader: A function to read the DataFrame from the file. 
+ """ + with NamedTemporaryFile(suffix=suffix) as temp_file: + temp_path = Path(temp_file.name) + assert temp_path.stat().st_size == 0 + assert ( + write_df(fname=fname_transformer(temp_path), df=df, extension=extension) + is None + ) + assert temp_path.stat().st_size > 0 + test_df = path_reader(temp_path) + assert test_df.equals(df) From 6ae23b34ca3c094ff26371acf12a742816363926 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 12:19:56 -0400 Subject: [PATCH 03/20] Overhual read_df unit tests Overhauled the gempyor.utils.read_df unit tests by blacing them in a new file with a class for grouping similar fixtures. Added tests for the NotImplementedError, reading from csv, and reading from parquet. --- .../gempyor_pkg/tests/utils/test_read_df.py | 128 ++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 flepimop/gempyor_pkg/tests/utils/test_read_df.py diff --git a/flepimop/gempyor_pkg/tests/utils/test_read_df.py b/flepimop/gempyor_pkg/tests/utils/test_read_df.py new file mode 100644 index 000000000..d25ed9a59 --- /dev/null +++ b/flepimop/gempyor_pkg/tests/utils/test_read_df.py @@ -0,0 +1,128 @@ +import os +from tempfile import NamedTemporaryFile +from pathlib import Path +from typing import Callable, Any + +import pytest +import pandas as pd + +from gempyor.utils import read_df + + +class TestReadDf: + """ + Unit tests for the `gempyor.utils.read_df` function. + """ + + sample_df: pd.DataFrame = pd.DataFrame( + { + "abc": [1, 2, 3, 4, 5], + "def": ["v", "w", "x", "y", "z"], + "ghi": [True, False, False, None, True], + "jkl": [1.2, 3.4, 5.6, 7.8, 9.0], + } + ) + + def test_raises_not_implemented_error(self) -> None: + """ + Tests that write_df raises a NotImplementedError for unsupported file + extensions. + """ + with pytest.raises( + expected_exception=NotImplementedError, + match="Invalid extension txt. Must be 'csv' or 'parquet'", + ) as _: + with NamedTemporaryFile(suffix=".txt") as temp_file: + read_df(fname=temp_file.name) + with pytest.raises( + expected_exception=NotImplementedError, + match="Invalid extension txt. Must be 'csv' or 'parquet'", + ) as _: + with NamedTemporaryFile(suffix=".txt") as temp_file: + fname = temp_file.name[:-4] + read_df(fname=fname, extension="txt") + + @pytest.mark.parametrize( + "fname_transformer,extension", + [ + (lambda x: str(x), ""), + (lambda x: x, ""), + (lambda x: f"{x.parent}/{x.stem}", "csv"), + (lambda x: Path(f"{x.parent}/{x.stem}"), "csv"), + ], + ) + def test_read_csv_dataframe( + self, + fname_transformer: Callable[[os.PathLike], Any], + extension: str, + ) -> None: + """ + Tests reading a DataFrame from a CSV file. + + Args: + fname_transformer: A function that transforms the file name to create the + `fname` arg. + extension: The file extension to use, provided directly to + `gempyor.utils.read_df`. + """ + self._test_read_df( + fname_transformer=fname_transformer, + extension=extension, + suffix=".csv", + path_writer=lambda p, df: df.to_csv(p, index=False) + ) + + @pytest.mark.parametrize( + "fname_transformer,extension", + [ + (lambda x: str(x), ""), + (lambda x: x, ""), + (lambda x: f"{x.parent}/{x.stem}", "parquet"), + (lambda x: Path(f"{x.parent}/{x.stem}"), "parquet"), + ], + ) + def test_read_parquet_dataframe( + self, + fname_transformer: Callable[[os.PathLike], Any], + extension: str, + ) -> None: + """ + Tests reading a DataFrame from a Parquet file. + + Args: + fname_transformer: A function that transforms the file name to create the + `fname` arg. 
+ extension: The file extension to use, provided directly to + `gempyor.utils.read_df`. + """ + self._test_read_df( + fname_transformer=fname_transformer, + extension=extension, + suffix=".parquet", + path_writer=lambda p, df: df.to_parquet(p, engine='pyarrow', index=False), + ) + + def _test_read_df( + self, + fname_transformer: Callable[[os.PathLike], Any], + extension: str, + suffix: str | None, + path_writer: Callable[[os.PathLike, pd.DataFrame], None], + ) -> None: + """ + Helper method to test writing a DataFrame to a file. + + Args: + fname_transformer: A function that transforms the file name. + extension: The file extension to use. + suffix: The suffix to use for the temporary file. + path_writer: A function to write the DataFrame to the file. + """ + with NamedTemporaryFile(suffix=suffix) as temp_file: + temp_path = Path(temp_file.name) + assert temp_path.stat().st_size == 0 + path_writer(temp_path, self.sample_df) + test_df = read_df(fname=fname_transformer(temp_path), extension=extension) + assert isinstance(test_df, pd.DataFrame) + assert temp_path.stat().st_size > 0 + assert test_df.equals(self.sample_df) From 9ad1abdd9cfc898f69c14144162988533726a5ea Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 12:38:24 -0400 Subject: [PATCH 04/20] Formatted read_df tests, explicit parquet engine * Formatted the `tests/utils/test_read_df.py` file. * Added `engine="pyarrow"` to `write_df` unit tests. --- flepimop/gempyor_pkg/tests/utils/test_read_df.py | 10 +++++----- flepimop/gempyor_pkg/tests/utils/test_write_df.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/flepimop/gempyor_pkg/tests/utils/test_read_df.py b/flepimop/gempyor_pkg/tests/utils/test_read_df.py index d25ed9a59..0b2334602 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_read_df.py +++ b/flepimop/gempyor_pkg/tests/utils/test_read_df.py @@ -41,7 +41,7 @@ def test_raises_not_implemented_error(self) -> None: with NamedTemporaryFile(suffix=".txt") as temp_file: fname = temp_file.name[:-4] read_df(fname=fname, extension="txt") - + @pytest.mark.parametrize( "fname_transformer,extension", [ @@ -69,9 +69,9 @@ def test_read_csv_dataframe( fname_transformer=fname_transformer, extension=extension, suffix=".csv", - path_writer=lambda p, df: df.to_csv(p, index=False) + path_writer=lambda p, df: df.to_csv(p, index=False), ) - + @pytest.mark.parametrize( "fname_transformer,extension", [ @@ -99,9 +99,9 @@ def test_read_parquet_dataframe( fname_transformer=fname_transformer, extension=extension, suffix=".parquet", - path_writer=lambda p, df: df.to_parquet(p, engine='pyarrow', index=False), + path_writer=lambda p, df: df.to_parquet(p, engine="pyarrow", index=False), ) - + def _test_read_df( self, fname_transformer: Callable[[os.PathLike], Any], diff --git a/flepimop/gempyor_pkg/tests/utils/test_write_df.py b/flepimop/gempyor_pkg/tests/utils/test_write_df.py index d465f6434..aad3bc1c4 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_write_df.py +++ b/flepimop/gempyor_pkg/tests/utils/test_write_df.py @@ -94,7 +94,7 @@ def test_write_parquet_dataframe( df=self.sample_df, extension=extension, suffix=".parquet", - path_reader=lambda x: pd.read_parquet(x), + path_reader=lambda x: pd.read_parquet(x, engine="pyarrow"), ) def _test_write_df( From 3d22841f76812f2de65c0eaaf087c9b76e435c57 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 12:54:11 -0400 Subject: [PATCH 05/20] Reorganized utils.py Moved read_df to be next to write_df in utils.py. 
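With the two functions side by side, their intended symmetry is easier to see.
A minimal sketch of the pair in use, assuming an existing DataFrame `df` and
illustrative paths only:

    from gempyor.utils import read_df, write_df

    # An explicit extension is appended when the file name has no suffix.
    write_df("model_output/hosp", df, extension="parquet")        # writes model_output/hosp.parquet
    df_again = read_df("model_output/hosp", extension="parquet")  # reads the same file back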
--- flepimop/gempyor_pkg/src/gempyor/utils.py | 34 +++++++++++------------ 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 4d3209061..2f13f0f57 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -38,6 +38,23 @@ def write_df(fname: str, df: pd.DataFrame, extension: str = ""): raise NotImplementedError(f"Invalid extension {extension}. Must be 'csv' or 'parquet'") +def read_df(fname: str, extension: str = "") -> pd.DataFrame: + """Load a dataframe from a file, agnostic to whether it is a parquet or a csv. The extension + can be provided as an argument or it is infered""" + fname = str(fname) + if extension: # Empty strings are falsy in python + fname = f"{fname}.{extension}" + extension = fname.split(".")[-1] + if extension == "csv": + # The converter prevents e.g leading geoid (0600) to be converted as int; and works when the column is absent + df = pd.read_csv(fname, converters={"subpop": lambda x: str(x)}, skipinitialspace=True) + elif extension == "parquet": + df = pa.parquet.read_table(fname).to_pandas() + else: + raise NotImplementedError(f"Invalid extension {extension}. Must be 'csv' or 'parquet'") + return df + + def command_safe_run(command, command_name="mycommand", fail_on_fail=True): import subprocess import shlex # using shlex to split the command because it's not obvious https://docs.python.org/3/library/subprocess.html#subprocess.Popen @@ -62,23 +79,6 @@ def command_safe_run(command, command_name="mycommand", fail_on_fail=True): return sr.returncode, stdout, stderr -def read_df(fname: str, extension: str = "") -> pd.DataFrame: - """Load a dataframe from a file, agnostic to whether it is a parquet or a csv. The extension - can be provided as an argument or it is infered""" - fname = str(fname) - if extension: # Empty strings are falsy in python - fname = f"{fname}.{extension}" - extension = fname.split(".")[-1] - if extension == "csv": - # The converter prevents e.g leading geoid (0600) to be converted as int; and works when the column is absent - df = pd.read_csv(fname, converters={"subpop": lambda x: str(x)}, skipinitialspace=True) - elif extension == "parquet": - df = pa.parquet.read_table(fname).to_pandas() - else: - raise NotImplementedError(f"Invalid extension {extension}. Must be 'csv' or 'parquet'") - return df - - def add_method(cls): "Decorator to add a method to a class" From df794e92609d1b56a2b8faf9340d74364c242902 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 13:02:08 -0400 Subject: [PATCH 06/20] Documented and extended write_df * Documented the gempyor.utils.write_df function using the Google style guide. * Extended write_df to explicitly support os.PathLike types for fname (was implicitly supported) and added support for bytes. * Changed the file manipulation logic to use pathlib rather than manipulating strings in write_df. 
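Under the new signature, each of the following calls writes the same CSV; a
short sketch, again assuming an existing DataFrame `df`:

    from pathlib import Path

    from gempyor.utils import write_df

    write_df("out/hosp.csv", df)                     # str with an explicit suffix
    write_df(Path("out/hosp"), df, extension="csv")  # os.PathLike plus an extension
    write_df(b"out/hosp.csv", df)                    # bytes are decoded first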
--- flepimop/gempyor_pkg/src/gempyor/utils.py | 49 ++++++++++++++++------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 2f13f0f57..bace01582 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -16,26 +16,47 @@ from gempyor import file_paths from typing import List, Dict from botocore.exceptions import ClientError +from pathlib import Path logger = logging.getLogger(__name__) config = confuse.Configuration("flepiMoP", read=False) -def write_df(fname: str, df: pd.DataFrame, extension: str = ""): - """write without index, so assume the index has been put a column""" - # cast to str to use .split in case fname is a PosixPath - fname = str(fname) - if extension: # Empty strings are falsy in python - fname = f"{fname}.{extension}" - extension = fname.split(".")[-1] - if extension == "csv": - df.to_csv(fname, index=False) - elif extension == "parquet": - df = pa.Table.from_pandas(df, preserve_index=False) - pa.parquet.write_table(df, fname) - else: - raise NotImplementedError(f"Invalid extension {extension}. Must be 'csv' or 'parquet'") +def write_df( + fname: str | bytes | os.PathLike, + df: pd.DataFrame, + extension: str = "", +) -> None: + """Writes a pandas DataFrame without its index to a file. + + Writes a pandas DataFrame to either a CSV or Parquet file without its index and can + infer which format to use based on the extension given in `fname` or based on + explicit `extension`. + + Args: + fname: The name of the file to write to. + df: A pandas DataFrame whose contents to write, but without its index. + extension: A user specified extension to use for the file if not contained in + `fname` already. + + Returns: + None + + Raises: + NotImplementedError: The given output extension is not supported yet. + """ + # Decipher the path given + fname = fname.decode() if isinstance(fname, bytes) else fname + path = Path(f"{fname}.{extension}") if extension else Path(fname) + # Write df to either a csv or parquet or raise if an invalid extension + if path.suffix == ".csv": + return df.to_csv(path, index=False) + elif path.suffix == ".parquet": + return df.to_parquet(path, index=False, engine="pyarrow") + raise NotImplementedError( + f"Invalid extension {extension}. Must be 'csv' or 'parquet'" + ) def read_df(fname: str, extension: str = "") -> pd.DataFrame: From 19078a03595284ab42d9ef7ee5d4cc0fa7a8e4fe Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 13:48:39 -0400 Subject: [PATCH 07/20] Add test for read_df with subpop column * Added unit test reading a file with a column called 'subpop', converted to a string when the file is a csv and left unaltered when the file is a parquet file. * Typo in test_raises_not_implemented_error docstring. 
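The behavior under test comes from the 'subpop' converter in read_df: values in
that column are read back as strings from a csv (so leading zeros in geoids
survive), while a parquet file keeps its stored dtype. A small sketch with
hypothetical file names:

    import pandas as pd

    from gempyor.utils import read_df

    df = pd.DataFrame({"subpop": [600, 601], "value": [1, 2]})
    df.to_csv("example.csv", index=False)
    df.to_parquet("example.parquet", index=False)

    read_df("example.csv")["subpop"].tolist()      # ['600', '601'] -- cast to str
    read_df("example.parquet")["subpop"].tolist()  # [600, 601] -- dtype preserved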
--- .../gempyor_pkg/tests/utils/test_read_df.py | 42 ++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/flepimop/gempyor_pkg/tests/utils/test_read_df.py b/flepimop/gempyor_pkg/tests/utils/test_read_df.py index 0b2334602..c4c9a14b0 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_read_df.py +++ b/flepimop/gempyor_pkg/tests/utils/test_read_df.py @@ -5,6 +5,7 @@ import pytest import pandas as pd +from pandas.api.types import is_object_dtype, is_numeric_dtype from gempyor.utils import read_df @@ -23,9 +24,16 @@ class TestReadDf: } ) + subpop_df: pd.DataFrame = pd.DataFrame( + { + "subpop": [1, 2, 3, 4], + "value": [5, 6, 7, 8], + } + ) + def test_raises_not_implemented_error(self) -> None: """ - Tests that write_df raises a NotImplementedError for unsupported file + Tests that read_df raises a NotImplementedError for unsupported file extensions. """ with pytest.raises( @@ -102,6 +110,38 @@ def test_read_parquet_dataframe( path_writer=lambda p, df: df.to_parquet(p, engine="pyarrow", index=False), ) + def test_subpop_is_cast_as_str(self) -> None: + """ + Tests that read_df returns an object dtype for the column 'subpop' when reading + a csv file, but not when reading a parquet file. + """ + # First confirm the dtypes of our test DataFrame + assert is_numeric_dtype(self.subpop_df["subpop"]) + assert is_numeric_dtype(self.subpop_df["value"]) + # Test that the subpop column is converted to a string for a csv file + with NamedTemporaryFile(suffix=".csv") as temp_file: + temp_path = Path(temp_file.name) + assert temp_path.stat().st_size == 0 + assert self.subpop_df.to_csv(temp_path, index=False) is None + assert temp_path.stat().st_size > 0 + test_df = read_df(fname=temp_path) + assert isinstance(test_df, pd.DataFrame) + assert is_object_dtype(test_df["subpop"]) + assert is_numeric_dtype(test_df["value"]) + # Test that the subpop column remains unaltered for a parquet file + with NamedTemporaryFile(suffix=".parquet") as temp_file: + temp_path = Path(temp_file.name) + assert temp_path.stat().st_size == 0 + assert ( + self.subpop_df.to_parquet(temp_path, engine="pyarrow", index=False) + is None + ) + assert temp_path.stat().st_size > 0 + test_df = read_df(fname=temp_path) + assert isinstance(test_df, pd.DataFrame) + assert is_numeric_dtype(test_df["subpop"]) + assert is_numeric_dtype(test_df["value"]) + def _test_read_df( self, fname_transformer: Callable[[os.PathLike], Any], From fd4021cf5a3b7955d9aab7cf1cf76724f849fb73 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 13:59:49 -0400 Subject: [PATCH 08/20] Documented and extended read_df * Documented the gempyor.utils.read_df function using the Google style guide. * Extended read_df to explicitly support os.PathLike types for fname (was implicitly supported) and added support for bytes. * Changed the file manipulation logic to use pathlib rather than manipulating strings in read_df. --- flepimop/gempyor_pkg/src/gempyor/utils.py | 47 +++++++++++++++-------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index bace01582..f59646dc0 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -59,21 +59,38 @@ def write_df( ) -def read_df(fname: str, extension: str = "") -> pd.DataFrame: - """Load a dataframe from a file, agnostic to whether it is a parquet or a csv. 
The extension - can be provided as an argument or it is infered""" - fname = str(fname) - if extension: # Empty strings are falsy in python - fname = f"{fname}.{extension}" - extension = fname.split(".")[-1] - if extension == "csv": - # The converter prevents e.g leading geoid (0600) to be converted as int; and works when the column is absent - df = pd.read_csv(fname, converters={"subpop": lambda x: str(x)}, skipinitialspace=True) - elif extension == "parquet": - df = pa.parquet.read_table(fname).to_pandas() - else: - raise NotImplementedError(f"Invalid extension {extension}. Must be 'csv' or 'parquet'") - return df +def read_df(fname: str | bytes | os.PathLike, extension: str = "") -> pd.DataFrame: + """Reads a pandas DataFrame from either a CSV or Parquet file. + + Reads a pandas DataFrame to either a CSV or Parquet file and can infer which format + to use based on the extension given in `fname` or based on explicit `extension`. If + the file being read is a csv with a column called 'subpop' then that column will be + cast as a string. + + Args: + fname: The name of the file to read from. + extension: A user specified extension to use for the file if not contained in + `fname` already. + + Returns: + A pandas DataFrame parsed from the file given. + + Raises: + NotImplementedError: The given output extension is not supported yet. + """ + # Decipher the path given + fname = fname.decode() if isinstance(fname, bytes) else fname + path = Path(f"{fname}.{extension}") if extension else Path(fname) + # Read df from either a csv or parquet or raise if an invalid extension + if path.suffix == ".csv": + return pd.read_csv( + path, converters={"subpop": lambda x: str(x)}, skipinitialspace=True + ) + elif path.suffix == ".parquet": + return pd.read_parquet(path, engine="pyarrow") + raise NotImplementedError( + f"Invalid extension {extension}. Must be 'csv' or 'parquet'" + ) def command_safe_run(command, command_name="mycommand", fail_on_fail=True): From c51a509aeb677b09c5035efaabf77812e15d2630 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 14:06:56 -0400 Subject: [PATCH 09/20] Changed extension type from str to Literal * Changed extension param of read_df/write_df from str to Literal[None, "", "csv", "parquet"]. * Added unit tests for `extension=None` for both read_df/write_df. --- flepimop/gempyor_pkg/src/gempyor/utils.py | 9 ++++++--- flepimop/gempyor_pkg/tests/utils/test_read_df.py | 12 ++++++++---- flepimop/gempyor_pkg/tests/utils/test_write_df.py | 12 ++++++++---- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index f59646dc0..2ddf74f24 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -14,7 +14,7 @@ import logging import boto3 from gempyor import file_paths -from typing import List, Dict +from typing import List, Dict, Literal from botocore.exceptions import ClientError from pathlib import Path @@ -26,7 +26,7 @@ def write_df( fname: str | bytes | os.PathLike, df: pd.DataFrame, - extension: str = "", + extension: Literal[None, "", "csv", "parquet"] = "", ) -> None: """Writes a pandas DataFrame without its index to a file. 
@@ -59,7 +59,10 @@ def write_df( ) -def read_df(fname: str | bytes | os.PathLike, extension: str = "") -> pd.DataFrame: +def read_df( + fname: str | bytes | os.PathLike, + extension: Literal[None, "", "csv", "parquet"] = "", +) -> pd.DataFrame: """Reads a pandas DataFrame from either a CSV or Parquet file. Reads a pandas DataFrame to either a CSV or Parquet file and can infer which format diff --git a/flepimop/gempyor_pkg/tests/utils/test_read_df.py b/flepimop/gempyor_pkg/tests/utils/test_read_df.py index c4c9a14b0..68345ba29 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_read_df.py +++ b/flepimop/gempyor_pkg/tests/utils/test_read_df.py @@ -1,7 +1,7 @@ import os from tempfile import NamedTemporaryFile from pathlib import Path -from typing import Callable, Any +from typing import Callable, Any, Literal import pytest import pandas as pd @@ -55,6 +55,8 @@ def test_raises_not_implemented_error(self) -> None: [ (lambda x: str(x), ""), (lambda x: x, ""), + (lambda x: str(x), None), + (lambda x: x, None), (lambda x: f"{x.parent}/{x.stem}", "csv"), (lambda x: Path(f"{x.parent}/{x.stem}"), "csv"), ], @@ -62,7 +64,7 @@ def test_raises_not_implemented_error(self) -> None: def test_read_csv_dataframe( self, fname_transformer: Callable[[os.PathLike], Any], - extension: str, + extension: Literal[None, "", "csv", "parquet"], ) -> None: """ Tests reading a DataFrame from a CSV file. @@ -85,6 +87,8 @@ def test_read_csv_dataframe( [ (lambda x: str(x), ""), (lambda x: x, ""), + (lambda x: str(x), None), + (lambda x: x, None), (lambda x: f"{x.parent}/{x.stem}", "parquet"), (lambda x: Path(f"{x.parent}/{x.stem}"), "parquet"), ], @@ -92,7 +96,7 @@ def test_read_csv_dataframe( def test_read_parquet_dataframe( self, fname_transformer: Callable[[os.PathLike], Any], - extension: str, + extension: Literal[None, "", "csv", "parquet"], ) -> None: """ Tests reading a DataFrame from a Parquet file. @@ -145,7 +149,7 @@ def test_subpop_is_cast_as_str(self) -> None: def _test_read_df( self, fname_transformer: Callable[[os.PathLike], Any], - extension: str, + extension: Literal[None, "", "csv", "parquet"], suffix: str | None, path_writer: Callable[[os.PathLike, pd.DataFrame], None], ) -> None: diff --git a/flepimop/gempyor_pkg/tests/utils/test_write_df.py b/flepimop/gempyor_pkg/tests/utils/test_write_df.py index aad3bc1c4..22c29240b 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_write_df.py +++ b/flepimop/gempyor_pkg/tests/utils/test_write_df.py @@ -1,7 +1,7 @@ import os from tempfile import NamedTemporaryFile from pathlib import Path -from typing import Callable, Any +from typing import Callable, Any, Literal import pytest import pandas as pd @@ -40,6 +40,8 @@ def test_raises_not_implemented_error(self) -> None: [ (lambda x: str(x), ""), (lambda x: x, ""), + (lambda x: str(x), None), + (lambda x: x, None), (lambda x: f"{x.parent}/{x.stem}", "csv"), (lambda x: Path(f"{x.parent}/{x.stem}"), "csv"), ], @@ -47,7 +49,7 @@ def test_raises_not_implemented_error(self) -> None: def test_write_csv_dataframe( self, fname_transformer: Callable[[os.PathLike], Any], - extension: str, + extension: Literal[None, "", "csv", "parquet"], ) -> None: """ Tests writing a DataFrame to a CSV file. 
@@ -71,6 +73,8 @@ def test_write_csv_dataframe( [ (lambda x: str(x), ""), (lambda x: x, ""), + (lambda x: str(x), None), + (lambda x: x, None), (lambda x: f"{x.parent}/{x.stem}", "parquet"), (lambda x: Path(f"{x.parent}/{x.stem}"), "parquet"), ], @@ -78,7 +82,7 @@ def test_write_csv_dataframe( def test_write_parquet_dataframe( self, fname_transformer: Callable[[os.PathLike], Any], - extension: str, + extension: Literal[None, "", "csv", "parquet"], ) -> None: """ Tests writing a DataFrame to a Parquet file. @@ -101,7 +105,7 @@ def _test_write_df( self, fname_transformer: Callable[[os.PathLike], Any], df: pd.DataFrame, - extension: str, + extension: Literal[None, "", "csv", "parquet"], suffix: str | None, path_reader: Callable[[os.PathLike], pd.DataFrame], ) -> None: From 18d9498e5ba7bf434832719f71388cbd85cdb961 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 14:13:29 -0400 Subject: [PATCH 10/20] Formatted/organized utils * Applied formatting (all whitespace) to read_df/write_df functions. * Reorganized the imports in utils to be clearer. --- flepimop/gempyor_pkg/src/gempyor/utils.py | 45 ++++++++++++----------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 2ddf74f24..e6056b815 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -1,22 +1,25 @@ -import os import datetime import functools +import logging import numbers +import os +from pathlib import Path +import shutil +import subprocess import time +from typing import List, Dict, Literal + +import boto3 +from botocore.exceptions import ClientError import confuse import numpy as np import pandas as pd import pyarrow as pa import scipy.stats import sympy.parsing.sympy_parser -import subprocess -import shutil -import logging -import boto3 + from gempyor import file_paths -from typing import List, Dict, Literal -from botocore.exceptions import ClientError -from pathlib import Path + logger = logging.getLogger(__name__) @@ -24,25 +27,25 @@ def write_df( - fname: str | bytes | os.PathLike, - df: pd.DataFrame, + fname: str | bytes | os.PathLike, + df: pd.DataFrame, extension: Literal[None, "", "csv", "parquet"] = "", ) -> None: """Writes a pandas DataFrame without its index to a file. - + Writes a pandas DataFrame to either a CSV or Parquet file without its index and can - infer which format to use based on the extension given in `fname` or based on + infer which format to use based on the extension given in `fname` or based on explicit `extension`. - + Args: fname: The name of the file to write to. df: A pandas DataFrame whose contents to write, but without its index. extension: A user specified extension to use for the file if not contained in `fname` already. - + Returns: None - + Raises: NotImplementedError: The given output extension is not supported yet. """ @@ -60,24 +63,24 @@ def write_df( def read_df( - fname: str | bytes | os.PathLike, + fname: str | bytes | os.PathLike, extension: Literal[None, "", "csv", "parquet"] = "", ) -> pd.DataFrame: """Reads a pandas DataFrame from either a CSV or Parquet file. - - Reads a pandas DataFrame to either a CSV or Parquet file and can infer which format + + Reads a pandas DataFrame to either a CSV or Parquet file and can infer which format to use based on the extension given in `fname` or based on explicit `extension`. If the file being read is a csv with a column called 'subpop' then that column will be cast as a string. 
- + Args: fname: The name of the file to read from. extension: A user specified extension to use for the file if not contained in `fname` already. - + Returns: A pandas DataFrame parsed from the file given. - + Raises: NotImplementedError: The given output extension is not supported yet. """ From 0e8241d62101a9ca56029af373befc52f7fac728 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Mon, 1 Jul 2024 14:49:28 -0400 Subject: [PATCH 11/20] Typo in NotImplementedError, corresponding tests * Added missing period in NotImplementedError in read_df/write_df, updated corresponding unit tests. Also use path suffix directly instead of given extension. * Added missing test for write_df with provided extension raising NotImplementedError. --- flepimop/gempyor_pkg/src/gempyor/utils.py | 4 ++-- flepimop/gempyor_pkg/tests/utils/test_read_df.py | 4 ++-- flepimop/gempyor_pkg/tests/utils/test_write_df.py | 9 ++++++++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index e6056b815..075162ce4 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -58,7 +58,7 @@ def write_df( elif path.suffix == ".parquet": return df.to_parquet(path, index=False, engine="pyarrow") raise NotImplementedError( - f"Invalid extension {extension}. Must be 'csv' or 'parquet'" + f"Invalid extension {path.suffix[1:]}. Must be 'csv' or 'parquet'." ) @@ -95,7 +95,7 @@ def read_df( elif path.suffix == ".parquet": return pd.read_parquet(path, engine="pyarrow") raise NotImplementedError( - f"Invalid extension {extension}. Must be 'csv' or 'parquet'" + f"Invalid extension {path.suffix[1:]}. Must be 'csv' or 'parquet'." ) diff --git a/flepimop/gempyor_pkg/tests/utils/test_read_df.py b/flepimop/gempyor_pkg/tests/utils/test_read_df.py index 68345ba29..7a0a0c581 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_read_df.py +++ b/flepimop/gempyor_pkg/tests/utils/test_read_df.py @@ -38,13 +38,13 @@ def test_raises_not_implemented_error(self) -> None: """ with pytest.raises( expected_exception=NotImplementedError, - match="Invalid extension txt. Must be 'csv' or 'parquet'", + match="Invalid extension txt. Must be 'csv' or 'parquet'.", ) as _: with NamedTemporaryFile(suffix=".txt") as temp_file: read_df(fname=temp_file.name) with pytest.raises( expected_exception=NotImplementedError, - match="Invalid extension txt. Must be 'csv' or 'parquet'", + match="Invalid extension txt. Must be 'csv' or 'parquet'.", ) as _: with NamedTemporaryFile(suffix=".txt") as temp_file: fname = temp_file.name[:-4] diff --git a/flepimop/gempyor_pkg/tests/utils/test_write_df.py b/flepimop/gempyor_pkg/tests/utils/test_write_df.py index 22c29240b..b13e0b948 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_write_df.py +++ b/flepimop/gempyor_pkg/tests/utils/test_write_df.py @@ -30,10 +30,17 @@ def test_raises_not_implemented_error(self) -> None: """ with pytest.raises( expected_exception=NotImplementedError, - match="Invalid extension txt. Must be 'csv' or 'parquet'", + match="Invalid extension txt. Must be 'csv' or 'parquet'.", ) as _: with NamedTemporaryFile(suffix=".txt") as temp_file: write_df(fname=temp_file.name, df=self.sample_df) + with pytest.raises( + expected_exception=NotImplementedError, + match="Invalid extension txt. 
Must be 'csv' or 'parquet'.", + ) as _: + with NamedTemporaryFile(suffix=".txt") as temp_file: + fname = temp_file.name[:-4] + write_df(fname=fname, df=self.sample_df, extension="txt") @pytest.mark.parametrize( "fname_transformer,extension", From 1b5b92d97371a91112ba28209cda1d9c9b46b2e9 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Fri, 5 Jul 2024 10:39:13 -0400 Subject: [PATCH 12/20] Added unit tests for list_filenames * Created unit tests in tests/utils/test_list_filenames.py, including tests for searching flat and nested folders. * Added relevant documentation to the tests as well. * Created create_directories_with_files pytest class fixture, might need to be extracted into a more general purpose location later. --- .../tests/utils/test_list_filenames.py | 166 ++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 flepimop/gempyor_pkg/tests/utils/test_list_filenames.py diff --git a/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py b/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py new file mode 100644 index 000000000..dea2b4cbc --- /dev/null +++ b/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py @@ -0,0 +1,166 @@ +"""Unit tests for the `gempyor.utils.list_filenames` function. + +These tests cover scenarios for finding files in both flat and nested directories. +""" + +from collections.abc import Generator +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest + +from gempyor.utils import list_filenames + + +@pytest.fixture(scope="class") +def create_directories_with_files( + request: pytest.FixtureRequest, +) -> Generator[tuple[TemporaryDirectory, TemporaryDirectory], None, None]: + """Fixture to create temporary directories with files for testing. + + This fixture creates two temporary directories: + - A flat directory with files. + - A nested directory with files organized in subdirectories. + + The directories and files are cleaned up after the tests are run. + + Args: + request: The pytest fixture request object. + + Yields: + tuple: A tuple containing the flat and nested TemporaryDirectory objects. 
+ """ + # Create a flat and nested directories + flat_temp_dir = TemporaryDirectory() + nested_temp_dir = TemporaryDirectory() + # Setup flat directory + for file in ["hosp.csv", "hosp.parquet", "spar.csv", "spar.parquet"]: + Path(f"{flat_temp_dir.name}/{file}").touch() + # Setup nested directory structure + for file in [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + "seed/001.csv", + "seed/002.csv", + ]: + path = Path(f"{nested_temp_dir.name}/{file}") + path.parent.mkdir(parents=True, exist_ok=True) + path.touch() + # Yield + request.cls.flat_temp_dir = flat_temp_dir + request.cls.nested_temp_dir = nested_temp_dir + yield (flat_temp_dir, nested_temp_dir) + # Clean up directories on test end + flat_temp_dir.cleanup() + nested_temp_dir.cleanup() + + +@pytest.mark.usefixtures("create_directories_with_files") +class TestListFilenames: + """Unit tests for the `gempyor.utils.list_filenames` function.""" + + @pytest.mark.parametrize( + "filters,expected_basenames", + [ + ([], ["hosp.csv", "hosp.parquet", "spar.csv", "spar.parquet"]), + (["hosp"], ["hosp.csv", "hosp.parquet"]), + (["spar"], ["spar.csv", "spar.parquet"]), + ([".parquet"], ["hosp.parquet", "spar.parquet"]), + ([".csv"], ["hosp.csv", "spar.csv"]), + (["hosp", ".csv"], ["hosp.csv"]), + (["spar", ".parquet"], ["spar.parquet"]), + (["hosp", "spar"], []), + ([".csv", ".parquet"], []), + ], + ) + def test_finds_files_in_flat_directory( + self, + filters: list[str], + expected_basenames: list[str], + ) -> None: + """Test `list_filenames` in a flat directory. + + Args: + filters: List of filters to apply to filenames. + expected_basenames: List of expected filenames that match the filters. + """ + self._test_list_filenames( + folder=self.flat_temp_dir.name, + filters=filters, + expected_basenames=expected_basenames, + ) + + @pytest.mark.parametrize( + "filters,expected_basenames", + [ + ( + [], + [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + "seed/001.csv", + "seed/002.csv", + ], + ), + ( + ["hpar"], + [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + ], + ), + (["seed"], ["seed/001.csv", "seed/002.csv"]), + (["global"], ["hpar/global/001.parquet", "hpar/global/002.parquet"]), + ( + ["001"], + [ + "hpar/chimeric/001.parquet", + "hpar/global/001.parquet", + "seed/001.csv", + ], + ), + (["hpar", "001"], ["hpar/chimeric/001.parquet", "hpar/global/001.parquet"]), + (["seed", "002"], ["seed/002.csv"]), + ([".tsv"], []), + ], + ) + def test_find_files_in_nested_directory( + self, + filters: list[str], + expected_basenames: list[str], + ) -> None: + """Test `list_filenames` in a nested directory. + + Args: + filters: List of filters to apply to filenames. + expected_basenames: List of expected filenames that match the filters. + """ + self._test_list_filenames( + folder=self.nested_temp_dir.name, + filters=filters, + expected_basenames=expected_basenames, + ) + + def _test_list_filenames( + self, + folder: str, + filters: list[str], + expected_basenames: list[str], + ) -> None: + """Helper method to test `list_filenames`. + + Args: + folder: The directory to search for files. + filters: List of filters to apply to filenames. + expected_basenames: List of expected filenames that match the filters. 
+ """ + files = list_filenames(folder=folder, filters=filters) + assert len(files) == len(expected_basenames) + basenames = [f.removeprefix(f"{folder}/") for f in files] + assert sorted(basenames) == sorted(expected_basenames) From d963fa407e90ec2b4c8552576db23e7f614aa942 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Fri, 5 Jul 2024 11:36:35 -0400 Subject: [PATCH 13/20] Documented & refactored list_filenames * Added more detail to the `gempyor.utils.list_filenames` type hints. * Formatted the documentation to comply with the Google style guide. * Refactored the internals of list_filenames to be single list comprehension instead of a loop of nested conditionals. * Allow `filters` to accept a single string. --- flepimop/gempyor_pkg/src/gempyor/utils.py | 67 ++++++++++++++--------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 075162ce4..afafc5f89 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -309,33 +309,48 @@ def as_random_distribution(self): return functools.partial(np.random.uniform, self.as_evaled_expression(), self.as_evaled_expression(),) -def list_filenames(folder: str = ".", filters: list = []) -> list: - """ - return the list of all filename and path in the provided folders. - If filters [list] is provided, then only the files that contains each of the - substrings in filter will be returned. Example to get all hosp file: - ``` - gempyor.utils.list_filenames(folder="model_output/", filters=["hosp"]) - ``` - and be sure we only get parquet: - ``` - gempyor.utils.list_filenames(folder="model_output/", filters=["hosp" , ".parquet"]) - ``` +def list_filenames(folder: str = ".", filters: str | list[str] = []) -> list[str]: + """Return the list of all filenames and paths in the provided folder. + + This function lists all files in the specified folder and its subdirectories. + If filters are provided, only the files containing each of the substrings + in the filters will be returned. + + Example: + To get all files containing "hosp": + ``` + gempyor.utils.list_filenames( + folder="model_output/", + filters=["hosp"], + ) + ``` + + To get only "hosp" files with a ".parquet" extension: + ``` + gempyor.utils.list_filenames( + folder="model_output/", + filters=["hosp", ".parquet"], + ) + ``` + + Args: + folder: The directory to search for files. Defaults to the current directory. + filters: A string or a list of strings to filter filenames. Only files + containing all the provided substrings will be returned. Defaults to an + empty list. + + Returns: + A list of strings representing the paths to the files that match the filters. 
""" - from pathlib import Path - - fn_list = [] - for f in Path(str(folder)).rglob(f"*"): - if f.is_file(): # not a folder - f = str(f) - if not filters: - fn_list.append(f) - else: - if all(c in f for c in filters): - fn_list.append(str(f)) - else: - pass - return fn_list + filters = list(filters) if not isinstance(filters, list) else filters + filters = filters if len(filters) else [""] + folder = Path(folder) + files = [ + str(file) + for file in folder.rglob("*") + if file.is_file() and all(f in str(file) for f in filters) + ] + return files def rolling_mean_pad(data, window): From 0e710f4b5750ec7a6ddd145b8d32a0fa7a037303 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Fri, 5 Jul 2024 12:02:36 -0400 Subject: [PATCH 14/20] Expanded types and tests for list_filenames * Expanded type support for the `folder` arg of `gempyor.utils.list_filnames` to support bytes and os.PathLike. * Vastly expanded test suite to target new supported types. * Corrected bug when filters was given as a string, uncovered by extensive test suite. --- flepimop/gempyor_pkg/src/gempyor/utils.py | 9 ++- .../tests/utils/test_list_filenames.py | 69 +++++++++++++++++-- 2 files changed, 71 insertions(+), 7 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index afafc5f89..5c894225a 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -309,7 +309,10 @@ def as_random_distribution(self): return functools.partial(np.random.uniform, self.as_evaled_expression(), self.as_evaled_expression(),) -def list_filenames(folder: str = ".", filters: str | list[str] = []) -> list[str]: +def list_filenames( + folder: str | bytes | os.PathLike = ".", + filters: str | list[str] = [], +) -> list[str]: """Return the list of all filenames and paths in the provided folder. This function lists all files in the specified folder and its subdirectories. @@ -342,9 +345,9 @@ def list_filenames(folder: str = ".", filters: str | list[str] = []) -> list[str Returns: A list of strings representing the paths to the files that match the filters. 
""" - filters = list(filters) if not isinstance(filters, list) else filters + filters = [filters] if not isinstance(filters, list) else filters filters = filters if len(filters) else [""] - folder = Path(folder) + folder = Path(folder.decode() if isinstance(folder, bytes) else folder) files = [ str(file) for file in folder.rglob("*") diff --git a/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py b/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py index dea2b4cbc..f24c6aaa6 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py +++ b/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py @@ -4,6 +4,7 @@ """ from collections.abc import Generator +import os from pathlib import Path from tempfile import TemporaryDirectory @@ -64,11 +65,18 @@ class TestListFilenames: @pytest.mark.parametrize( "filters,expected_basenames", [ + ("", ["hosp.csv", "hosp.parquet", "spar.csv", "spar.parquet"]), ([], ["hosp.csv", "hosp.parquet", "spar.csv", "spar.parquet"]), + ("hosp", ["hosp.csv", "hosp.parquet"]), (["hosp"], ["hosp.csv", "hosp.parquet"]), + ("spar", ["spar.csv", "spar.parquet"]), (["spar"], ["spar.csv", "spar.parquet"]), + (".parquet", ["hosp.parquet", "spar.parquet"]), ([".parquet"], ["hosp.parquet", "spar.parquet"]), + (".csv", ["hosp.csv", "spar.csv"]), ([".csv"], ["hosp.csv", "spar.csv"]), + (".tsv", []), + ([".tsv"], []), (["hosp", ".csv"], ["hosp.csv"]), (["spar", ".parquet"], ["spar.parquet"]), (["hosp", "spar"], []), @@ -77,7 +85,7 @@ class TestListFilenames: ) def test_finds_files_in_flat_directory( self, - filters: list[str], + filters: str | list[str], expected_basenames: list[str], ) -> None: """Test `list_filenames` in a flat directory. @@ -91,10 +99,31 @@ def test_finds_files_in_flat_directory( filters=filters, expected_basenames=expected_basenames, ) + self._test_list_filenames( + folder=self.flat_temp_dir.name.encode(), + filters=filters, + expected_basenames=expected_basenames, + ) + self._test_list_filenames( + folder=Path(self.flat_temp_dir.name), + filters=filters, + expected_basenames=expected_basenames, + ) @pytest.mark.parametrize( "filters,expected_basenames", [ + ( + "", + [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + "seed/001.csv", + "seed/002.csv", + ], + ), ( [], [ @@ -106,6 +135,15 @@ def test_finds_files_in_flat_directory( "seed/002.csv", ], ), + ( + "hpar", + [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + ], + ), ( ["hpar"], [ @@ -115,8 +153,18 @@ def test_finds_files_in_flat_directory( "hpar/global/002.parquet", ], ), + ("seed", ["seed/001.csv", "seed/002.csv"]), (["seed"], ["seed/001.csv", "seed/002.csv"]), + ("global", ["hpar/global/001.parquet", "hpar/global/002.parquet"]), (["global"], ["hpar/global/001.parquet", "hpar/global/002.parquet"]), + ( + "001", + [ + "hpar/chimeric/001.parquet", + "hpar/global/001.parquet", + "seed/001.csv", + ], + ), ( ["001"], [ @@ -127,12 +175,14 @@ def test_finds_files_in_flat_directory( ), (["hpar", "001"], ["hpar/chimeric/001.parquet", "hpar/global/001.parquet"]), (["seed", "002"], ["seed/002.csv"]), + (["hpar", "001", "global"], ["hpar/global/001.parquet"]), + (".tsv", []), ([".tsv"], []), ], ) def test_find_files_in_nested_directory( self, - filters: list[str], + filters: str | list[str], expected_basenames: list[str], ) -> None: """Test `list_filenames` in a nested directory. 
@@ -146,11 +196,21 @@ def test_find_files_in_nested_directory( filters=filters, expected_basenames=expected_basenames, ) + self._test_list_filenames( + folder=self.nested_temp_dir.name.encode(), + filters=filters, + expected_basenames=expected_basenames, + ) + self._test_list_filenames( + folder=Path(self.nested_temp_dir.name), + filters=filters, + expected_basenames=expected_basenames, + ) def _test_list_filenames( self, - folder: str, - filters: list[str], + folder: str | bytes | os.PathLike, + filters: str | list[str], expected_basenames: list[str], ) -> None: """Helper method to test `list_filenames`. @@ -162,5 +222,6 @@ def _test_list_filenames( """ files = list_filenames(folder=folder, filters=filters) assert len(files) == len(expected_basenames) + folder = folder.decode() if isinstance(folder, bytes) else str(folder) basenames = [f.removeprefix(f"{folder}/") for f in files] assert sorted(basenames) == sorted(expected_basenames) From 2c7a81d78ce8e5cfc19ac9fe2865a02e7b31900d Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Fri, 5 Jul 2024 13:28:27 -0400 Subject: [PATCH 15/20] Added unit tests for get_truncated_normal Added unit tests for `gempyor.utils.get_truncated_normal` function. --- .../tests/utils/test_get_truncated_normal.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py diff --git a/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py b/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py new file mode 100644 index 000000000..23e4fad58 --- /dev/null +++ b/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py @@ -0,0 +1,49 @@ +import numpy as np +import pytest +import scipy.stats + +from gempyor.utils import get_truncated_normal + + +class TestGetTruncatedNormal: + """Unit tests for the `gempyor.utils.get_truncated_normal` function.""" + + @pytest.mark.parametrize( + "mean,sd,a,b", + [ + (0.0, 1.0, 0.0, 10.0), + (0.0, 2.0, -4.0, 4.0), + (-5.0, 3.0, -5.0, 10.0), + (-3.25, 1.4, -8.74, 4.89), + (0, 1, 0, 10), + (0, 2, -4, 4), + (-5, 3, -5, 10), + ], + ) + def test_construct_distribution( + self, + mean: float | int, + sd: float | int, + a: float | int, + b: float | int, + ) -> None: + """Test the construction of a truncated normal distribution. + + This test checks whether the `get_truncated_normal` function correctly + constructs a truncated normal distribution with the specified parameters. + It verifies that the returned object is an instance of `rv_frozen`, and that + its support and parameters (mean and standard deviation) are correctly set. + + Args: + mean: The mean of the truncated normal distribution. + sd: The standard deviation of the truncated normal distribution. + a: The lower bound of the truncated normal distribution. + b: The upper bound of the truncated normal distribution. + """ + dist = get_truncated_normal(mean=mean, sd=sd, a=a, b=b) + assert isinstance(dist, scipy.stats._distn_infrastructure.rv_frozen) + lower, upper = dist.support() + assert np.isclose(lower, a) + assert np.isclose(upper, b) + assert np.isclose(dist.kwds.get("loc"), mean) + assert np.isclose(dist.kwds.get("scale"), sd) From 17c715d344023fd4e3a0926739e1da2814859d97 Mon Sep 17 00:00:00 2001 From: Timothy Willard Date: Fri, 5 Jul 2024 13:36:33 -0400 Subject: [PATCH 16/20] Documented and refactored get_truncated_normal * Documented `gempyor.utils.get_truncated_normal` including adding appropriate type hints. * Refactored the function lightly for legibility. 
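The refactor spells out the bound conversion: scipy.stats.truncnorm takes its
bounds in standard-normal units, so the original-scale bounds a and b are first
shifted by the mean and divided by the standard deviation. A short usage sketch:

    from gempyor.utils import get_truncated_normal

    dist = get_truncated_normal(mean=0, sd=2, a=-4, b=4)
    print(dist.support())  # (-4.0, 4.0): reported back on the original scale
    samples = dist.rvs(5)  # five draws, all guaranteed to fall inside [-4, 4]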
--- flepimop/gempyor_pkg/src/gempyor/utils.py | 27 ++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py
index 5c894225a..826727c86 100644
--- a/flepimop/gempyor_pkg/src/gempyor/utils.py
+++ b/flepimop/gempyor_pkg/src/gempyor/utils.py
@@ -252,10 +252,31 @@ def as_evaled_expression(self):
         raise ValueError(f"expected numeric or string expression [got: {value}]")
 
 
-def get_truncated_normal(*, mean=0, sd=1, a=0, b=10):
-    "Returns the truncated normal distribution"
+def get_truncated_normal(
+    mean: float | int = 0,
+    sd: float | int = 1,
+    a: float | int = 0,
+    b: float | int = 10,
+) -> scipy.stats._distn_infrastructure.rv_frozen:
+    """Returns a truncated normal distribution.
+
+    This function constructs a truncated normal distribution with the specified
+    mean, standard deviation, and bounds. The truncated normal distribution is
+    a normal distribution bounded within the interval [a, b].
 
-    return scipy.stats.truncnorm((a - mean) / sd, (b - mean) / sd, loc=mean, scale=sd)
+    Args:
+        mean: The mean of the truncated normal distribution. Defaults to 0.
+        sd: The standard deviation of the truncated normal distribution. Defaults to 1.
+        a: The lower bound of the truncated normal distribution. Defaults to 0.
+        b: The upper bound of the truncated normal distribution. Defaults to 10.
+
+    Returns:
+        rv_frozen: A frozen instance of the truncated normal distribution with the
+            specified parameters.
+    """
+    lower = (a - mean) / sd
+    upper = (b - mean) / sd
+    return scipy.stats.truncnorm(lower, upper, loc=mean, scale=sd)

From 2aaa6fbb8e7dd27f8e29c41ca9015920ce03130c Mon Sep 17 00:00:00 2001
From: Timothy Willard
Date: Fri, 5 Jul 2024 14:30:15 -0400
Subject: [PATCH 17/20] Added unit tests for get_log_normal

Added unit tests for the `gempyor.utils.get_log_normal` function.
---
 .../tests/utils/test_get_log_normal.py | 49 +++++++++++++++++++
 1 file changed, 49 insertions(+)
 create mode 100644 flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py

diff --git a/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py b/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py
new file mode 100644
index 000000000..367a7f550
--- /dev/null
+++ b/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py
@@ -0,0 +1,49 @@
+import numpy as np
+import pytest
+import scipy.stats
+
+from gempyor.utils import get_log_normal
+
+
+class TestGetLogNormal:
+    """Unit tests for the `gempyor.utils.get_log_normal` function."""
+
+    @pytest.mark.parametrize(
+        "meanlog,sdlog",
+        [
+            (1.0, 1.0),
+            (0.0, 2.0),
+            (10.0, 30.0),
+            (0.33, 4.56),
+            (9.87, 4.21),
+            (1, 1),
+            (0, 2),
+            (10, 30),
+        ],
+    )
+    def test_construct_distribution(
+        self,
+        meanlog: float | int,
+        sdlog: float | int,
+    ) -> None:
+        """Test the construction of a log normal distribution.
+
+        This test checks whether the `get_log_normal` function correctly constructs
+        a log normal distribution with the specified parameters. It verifies that
+        the returned object is an instance of `rv_frozen`, and that its support and
+        parameters (log mean and log standard deviation) are correctly set.
+
+        Args:
+            meanlog: The mean of the distribution on the log scale.
+            sdlog: The standard deviation of the distribution on the log scale.
+        """
+        dist = get_log_normal(meanlog=meanlog, sdlog=sdlog)
+        assert isinstance(dist, scipy.stats._distn_infrastructure.rv_frozen)
+        lower, upper = dist.support()
+        assert np.isclose(lower, 0.0)
+        assert np.isclose(upper, np.inf)
+        assert np.isclose(dist.kwds.get("s"), sdlog)
+        assert np.isclose(dist.kwds.get("scale"), np.exp(meanlog))
+        assert np.isclose(dist.kwds.get("loc"), 0.0)

From c09b9535ddfc364702bb048fad168fceab71ac3d Mon Sep 17 00:00:00 2001
From: Timothy Willard
Date: Fri, 5 Jul 2024 14:35:47 -0400
Subject: [PATCH 18/20] Documented get_log_normal

Added documentation for `gempyor.utils.get_log_normal` including adding
appropriate type hints.
---
 flepimop/gempyor_pkg/src/gempyor/utils.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py
index 826727c86..98c54a1f1 100644
--- a/flepimop/gempyor_pkg/src/gempyor/utils.py
+++ b/flepimop/gempyor_pkg/src/gempyor/utils.py
@@ -279,8 +279,23 @@ def get_truncated_normal(
     return scipy.stats.truncnorm(lower, upper, loc=mean, scale=sd)
 
 
-def get_log_normal(meanlog, sdlog):
-    "Returns the log normal distribution"
+def get_log_normal(
+    meanlog: float | int,
+    sdlog: float | int,
+) -> scipy.stats._distn_infrastructure.rv_frozen:
+    """Returns a log normal distribution.
+
+    This function constructs a log normal distribution with the specified
+    mean and standard deviation on the log scale.
+
+    Args:
+        meanlog: The mean of the log normal distribution on the log scale.
+        sdlog: The standard deviation of the log normal distribution on the log scale.
+
+    Returns:
+        rv_frozen: A frozen instance of the log normal distribution with the
+            specified parameters.
+    """
     return scipy.stats.lognorm(s=sdlog, scale=np.exp(meanlog), loc=0)

From e073f883fa028467386f7292f4cf3f2703a82358 Mon Sep 17 00:00:00 2001
From: Timothy Willard
Date: Mon, 8 Jul 2024 10:33:06 -0400
Subject: [PATCH 19/20] Added unit tests for rolling_mean_pad

* Added unit tests for `gempyor.utils.rolling_mean_pad` in
  `tests/utils/test_rolling_mean_pad.py`.
* Wrote a reference implementation of `rolling_mean_pad` for comparison
  purposes.
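
The comparison strategy the new tests use can be sketched briefly (a minimal
sketch, assuming `gempyor` is installed; the array here is illustrative only):

    import numpy as np
    from gempyor.utils import rolling_mean_pad

    # Smoothing must preserve the (n_days, n_subpops) shape; the unit tests
    # assert this, and that the values match the slow reference implementation.
    data = np.arange(12.0).reshape((6, 2))
    smoothed = rolling_mean_pad(data, window=3)
    assert smoothed.shape == data.shape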
--- .../tests/utils/test_rolling_mean_pad.py | 149 ++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py diff --git a/flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py b/flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py new file mode 100644 index 000000000..94be3394a --- /dev/null +++ b/flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py @@ -0,0 +1,149 @@ +import numpy as np +import numpy.typing as npt +import pytest + +from gempyor.utils import rolling_mean_pad + + +class TestRollingMeanPad: + """Unit tests for the `gempyor.utils.rolling_mean_pad` function.""" + + # Test data for various matrix configurations + test_data = { + # 1x1 matrices + "one_by_one_const": np.array([[1.0]]), + "one_by_one_nan": np.array([[np.nan]]), + "one_by_one_rand": np.random.uniform(size=(1, 1)), + # 1xN matrices + "one_by_many_const": np.arange(start=1.0, stop=6.0).reshape((1, 5)), + "one_by_many_nan": np.repeat(np.nan, 5).reshape((1, 5)), + "one_by_many_rand": np.random.uniform(size=(1, 5)), + # Mx1 matrices + "many_by_one_const": np.arange(start=3.0, stop=9.0).reshape((6, 1)), + "many_by_one_nan": np.repeat(np.nan, 6).reshape((6, 1)), + "many_by_one_rand": np.random.uniform(size=(6, 1)), + # MxN matrices + "many_by_many_const": np.arange(start=1.0, stop=49.0).reshape((12, 4)), + "many_by_many_nan": np.repeat(np.nan, 48).reshape((12, 4)), + "many_by_many_rand": np.random.uniform(size=(12, 4)), + } + + @pytest.mark.parametrize( + "test_data_name,expected_shape,window,put_nans", + [ + # 1x1 matrices + ("one_by_one_const", (1, 1), 3, []), + ("one_by_one_const", (1, 1), 4, []), + ("one_by_one_nan", (1, 1), 3, []), + ("one_by_one_nan", (1, 1), 4, []), + ("one_by_one_rand", (1, 1), 3, []), + ("one_by_one_rand", (1, 1), 4, []), + ("one_by_one_rand", (1, 1), 5, []), + ("one_by_one_rand", (1, 1), 6, []), + # 1xN matrices + ("one_by_many_const", (1, 5), 3, []), + ("one_by_many_const", (1, 5), 4, []), + ("one_by_many_nan", (1, 5), 3, []), + ("one_by_many_nan", (1, 5), 4, []), + ("one_by_many_rand", (1, 5), 3, []), + ("one_by_many_rand", (1, 5), 4, []), + ("one_by_many_rand", (1, 5), 5, []), + ("one_by_many_rand", (1, 5), 6, []), + # Mx1 matrices + ("many_by_one_const", (6, 1), 3, []), + ("many_by_one_const", (6, 1), 4, []), + ("many_by_one_nan", (6, 1), 3, []), + ("many_by_one_nan", (6, 1), 4, []), + ("many_by_one_rand", (6, 1), 3, []), + ("many_by_one_rand", (6, 1), 4, []), + ("many_by_one_rand", (6, 1), 5, []), + ("many_by_one_rand", (6, 1), 6, []), + # MxN matrices + ("many_by_many_const", (12, 4), 3, []), + ("many_by_many_const", (12, 4), 4, []), + ("many_by_many_const", (12, 4), 5, []), + ("many_by_many_const", (12, 4), 6, []), + ("many_by_many_nan", (12, 4), 3, []), + ("many_by_many_nan", (12, 4), 4, []), + ("many_by_many_nan", (12, 4), 5, []), + ("many_by_many_nan", (12, 4), 6, []), + ("many_by_many_rand", (12, 4), 3, []), + ("many_by_many_rand", (12, 4), 4, []), + ("many_by_many_rand", (12, 4), 5, []), + ("many_by_many_rand", (12, 4), 6, []), + ("many_by_many_rand", (12, 4), 7, []), + ("many_by_many_rand", (12, 4), 8, []), + ("many_by_many_rand", (12, 4), 9, []), + ("many_by_many_rand", (12, 4), 10, []), + ("many_by_many_rand", (12, 4), 11, []), + ("many_by_many_rand", (12, 4), 12, []), + ("many_by_many_rand", (12, 4), 13, []), + ("many_by_many_rand", (12, 4), 14, []), + ("many_by_many_rand", (12, 4), 3, [(2, 2), (4, 4)]), + ("many_by_many_rand", (12, 4), 4, [(2, 2), (4, 4)]), + ("many_by_many_rand", (12, 4), 5, 
[(2, 2), (4, 4)]),
+            ("many_by_many_rand", (12, 4), 6, [(2, 2), (4, 4)]),
+            ("many_by_many_rand", (12, 4), 7, [(2, 2), (4, 4)]),
+            ("many_by_many_rand", (12, 4), 8, [(2, 2), (4, 4)]),
+        ],
+    )
+    def test_rolling_mean_pad(
+        self,
+        test_data_name: str,
+        expected_shape: tuple[int, int],
+        window: int,
+        put_nans: list[tuple[int, int]],
+    ) -> None:
+        """Tests the `rolling_mean_pad` function with various configurations of input data.
+
+        Args:
+            test_data_name: The name of the test data set to use.
+            expected_shape: The expected shape of the output array.
+            window: The size of the rolling window.
+            put_nans: A list of indices to insert NaNs into the input data.
+
+        Raises:
+            AssertionError: If the shape or contents of the output do not match the
+                expected values.
+        """
+        test_data = self.test_data.get(test_data_name).copy()
+        if put_nans:
+            np.put(test_data, put_nans, np.nan)
+        rolling_mean_data = rolling_mean_pad(test_data, window)
+        rolling_mean_reference = self._rolling_mean_pad_reference(test_data, window)
+        assert rolling_mean_data.shape == expected_shape
+        assert np.isclose(
+            rolling_mean_data, rolling_mean_reference, equal_nan=True
+        ).all()
+
+    def _rolling_mean_pad_reference(
+        self, data: npt.NDArray[np.number], window: int
+    ) -> npt.NDArray[np.number]:
+        """Generates a reference rolling mean with padding.
+
+        This implementation should match the `gempyor.utils.rolling_mean_pad`
+        implementation, but is written for readability. As a result this
+        reference implementation is extremely slow.
+
+        Args:
+            data: The input array for which to compute the rolling mean.
+            window: The size of the rolling window.
+
+        Returns:
+            An array of the same shape as `data` containing the rolling mean values.
+        """
+        # Setup
+        rows, cols = data.shape
+        output = np.zeros((rows, cols), dtype=data.dtype)
+        # Slow but intuitive triple loop
+        for i in range(rows):
+            for j in range(cols):
+                # For the last row, an even window is shrunk by one, e.g. 4 -> 3,
+                # while an odd window is left unchanged, e.g. 5 -> 5.
+                sub_window = window - 1 if window % 2 == 0 and i == rows - 1 else window
+                weight = 1.0 / sub_window
+                for k in range(-((sub_window - 1) // 2), 1 + (sub_window // 2)):
+                    i_star = min(max(i + k, 0), rows - 1)
+                    output[i, j] += weight * data[i_star, j]
+        # Done
+        return output

From a91fc6df031181b0e537b88e0ca60f9f3df287db Mon Sep 17 00:00:00 2001
From: Timothy Willard
Date: Mon, 8 Jul 2024 14:15:57 -0400
Subject: [PATCH 20/20] Documented `rolling_mean_pad`

* Added type hints for `gempyor.utils.rolling_mean_pad`.
* Expanded the existing docstring and included an example.
---
 flepimop/gempyor_pkg/src/gempyor/utils.py | 37 +++++++++++++++++++++++++++++++++++++++----
 1 file changed, 33 insertions(+), 4 deletions(-)

diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py
index 98c54a1f1..85820d5b2 100644
--- a/flepimop/gempyor_pkg/src/gempyor/utils.py
+++ b/flepimop/gempyor_pkg/src/gempyor/utils.py
@@ -13,6 +13,7 @@ from botocore.exceptions import ClientError
 import confuse
 import numpy as np
+import numpy.typing as npt
 import pandas as pd
 import pyarrow as pa
 import scipy.stats
@@ -392,16 +393,44 @@ def list_filenames(
     return files
 
 
-def rolling_mean_pad(data, window):
+def rolling_mean_pad(
+    data: npt.NDArray[np.number],
+    window: int,
+) -> npt.NDArray[np.number]:
     """
-    Calculates rolling mean with centered window and pads the edges.
+    Calculates the column-wise rolling mean with a centered window.
 
     Args:
-        data: A NumPy array !!! shape must be (n_days, nsubpops).
+        data: A two dimensional numpy array; typically the row dimension is time
+            and the column dimension is subpop.
         window: The window size for the rolling mean.
 
     Returns:
-        A NumPy array with the padded rolling mean (n_days, nsubpops).
+        A two dimensional numpy array that is the same shape as `data`.
+
+    Examples:
+        Below is a brief set of examples showcasing how to smooth a metric, like
+        hospitalizations, using this function.
+
+        >>> import numpy as np
+        >>> from gempyor.utils import rolling_mean_pad
+        >>> hospitalizations = np.arange(1., 29.).reshape((7, 4))
+        >>> hospitalizations
+        array([[ 1.,  2.,  3.,  4.],
+               [ 5.,  6.,  7.,  8.],
+               [ 9., 10., 11., 12.],
+               [13., 14., 15., 16.],
+               [17., 18., 19., 20.],
+               [21., 22., 23., 24.],
+               [25., 26., 27., 28.]])
+        >>> rolling_mean_pad(hospitalizations, 5)
+        array([[ 3.4,  4.4,  5.4,  6.4],
+               [ 5.8,  6.8,  7.8,  8.8],
+               [ 9. , 10. , 11. , 12. ],
+               [13. , 14. , 15. , 16. ],
+               [17. , 18. , 19. , 20. ],
+               [20.2, 21.2, 22.2, 23.2],
+               [22.6, 23.6, 24.6, 25.6]])
     """
     padding_size = (window - 1) // 2
     padded_data = np.pad(data, ((padding_size, padding_size), (0, 0)), mode="edge")