diff --git a/flepimop/gempyor_pkg/src/gempyor/parameters.py b/flepimop/gempyor_pkg/src/gempyor/parameters.py index 8e0bdaa25..3d38b2c1f 100644 --- a/flepimop/gempyor_pkg/src/gempyor/parameters.py +++ b/flepimop/gempyor_pkg/src/gempyor/parameters.py @@ -106,6 +106,10 @@ def __init__( else: self.pdata[pn]["stacked_modifier_method"] = "product" logging.debug(f"No 'stacked_modifier_method' for parameter {pn}, assuming multiplicative NPIs") + + if self.pconfig[pn]["rolling_mean_windows"].exists(): + self.pdata[pn]["rolling_mean_windows"] = self.pconfig[pn]["rolling_mean_windows"].get() + self.stacked_modifier_method[self.pdata[pn]["stacked_modifier_method"]].append(pn.lower()) logging.debug(f"We have {self.npar} parameter: {self.pnames}") @@ -192,11 +196,14 @@ def parameters_reduce(self, p_draw: ndarray, npi: object) -> ndarray: """ p_reduced = copy.deepcopy(p_draw) if npi is not None: - for idx, pn in enumerate(self.pnames): + for idx, pn in enumerate(self.pnames): p_reduced[idx] = NPI.reduce_parameter( parameter=p_draw[idx], modification=npi.getReduction(pn.lower()), method=self.pdata[pn]["stacked_modifier_method"], ) + # apply rolling mean if specified + if "rolling_mean_windows" in self.pdata[pn]: + p_reduced[idx] = utils.rolling_mean_pad(data = p_reduced[idx], window=self.pdata[pn]["rolling_mean_windows"]) return p_reduced diff --git a/flepimop/gempyor_pkg/src/gempyor/seir.py b/flepimop/gempyor_pkg/src/gempyor/seir.py index c294719d3..4e59761f2 100644 --- a/flepimop/gempyor_pkg/src/gempyor/seir.py +++ b/flepimop/gempyor_pkg/src/gempyor/seir.py @@ -7,7 +7,7 @@ import xarray as xr from . import NPI, model_info, steps_rk4 -from .utils import Timer, aws_disk_diagnosis, read_df +from .utils import Timer, print_disk_diagnosis, read_df import logging logger = logging.getLogger(__name__) @@ -388,7 +388,7 @@ def write_spar_snpi(sim_id, modinf, p_draw, npi): def write_seir(sim_id, modinf, states): - # aws_disk_diagnosis() + # print_disk_diagnosis() out_df = states2Df(modinf, states) modinf.write_simID(ftype="seir", sim_id=sim_id, df=out_df) diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index b9ddd1ac1..1076d09bc 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -272,7 +272,24 @@ def list_filenames(folder: str = ".", filters: list = []) -> list: return fn_list -def aws_disk_diagnosis(): +def rolling_mean_pad(data, window): + """ + Calculates rolling mean with centered window and pads the edges. + + Args: + data: A NumPy array. + window: The window size for the rolling mean. + + Returns: + A NumPy array with the padded rolling mean. + """ + padding_size = (window - 1) // 2 + padded_data = np.pad(data, padding_size, mode='edge') + return np.convolve(padded_data, np.ones(window) / window, mode='valid') + + + +def print_disk_diagnosis(): import os from os import path from shutil import disk_usage diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils.py b/flepimop/gempyor_pkg/tests/utils/test_utils.py index 694a7296f..eedb3c4f3 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_utils.py +++ b/flepimop/gempyor_pkg/tests/utils/test_utils.py @@ -68,8 +68,8 @@ def test_Timer_with_statement_success(): time.sleep(1) -def test_aws_disk_diagnosis_success(): - utils.aws_disk_diagnosis() +def test_print_disk_diagnosis_success(): + utils.print_disk_diagnosis() def test_profile_success():