Skip to content

Commit

Permalink
new way of smoothing parameter to test
Browse files Browse the repository at this point in the history
  • Loading branch information
jcblemai committed Apr 25, 2024
1 parent ac076c0 commit be78600
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 6 deletions.
9 changes: 8 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def __init__(
else:
self.pdata[pn]["stacked_modifier_method"] = "product"
logging.debug(f"No 'stacked_modifier_method' for parameter {pn}, assuming multiplicative NPIs")

if self.pconfig[pn]["rolling_mean_windows"].exists():
self.pdata[pn]["rolling_mean_windows"] = self.pconfig[pn]["rolling_mean_windows"].get()

self.stacked_modifier_method[self.pdata[pn]["stacked_modifier_method"]].append(pn.lower())

logging.debug(f"We have {self.npar} parameter: {self.pnames}")
Expand Down Expand Up @@ -192,11 +196,14 @@ def parameters_reduce(self, p_draw: ndarray, npi: object) -> ndarray:
"""
p_reduced = copy.deepcopy(p_draw)
if npi is not None:
for idx, pn in enumerate(self.pnames):
for idx, pn in enumerate(self.pnames):
p_reduced[idx] = NPI.reduce_parameter(
parameter=p_draw[idx],
modification=npi.getReduction(pn.lower()),
method=self.pdata[pn]["stacked_modifier_method"],
)
# apply rolling mean if specified
if "rolling_mean_windows" in self.pdata[pn]:
p_reduced[idx] = utils.rolling_mean_pad(data = p_reduced[idx], window=self.pdata[pn]["rolling_mean_windows"])

return p_reduced
4 changes: 2 additions & 2 deletions flepimop/gempyor_pkg/src/gempyor/seir.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import xarray as xr

from . import NPI, model_info, steps_rk4
from .utils import Timer, aws_disk_diagnosis, read_df
from .utils import Timer, print_disk_diagnosis, read_df
import logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -388,7 +388,7 @@ def write_spar_snpi(sim_id, modinf, p_draw, npi):


def write_seir(sim_id, modinf, states):
# aws_disk_diagnosis()
# print_disk_diagnosis()
out_df = states2Df(modinf, states)
modinf.write_simID(ftype="seir", sim_id=sim_id, df=out_df)

Expand Down
19 changes: 18 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,24 @@ def list_filenames(folder: str = ".", filters: list = []) -> list:
return fn_list


def aws_disk_diagnosis():
def rolling_mean_pad(data, window):
"""
Calculates rolling mean with centered window and pads the edges.
Args:
data: A NumPy array.
window: The window size for the rolling mean.
Returns:
A NumPy array with the padded rolling mean.
"""
padding_size = (window - 1) // 2
padded_data = np.pad(data, padding_size, mode='edge')
return np.convolve(padded_data, np.ones(window) / window, mode='valid')



def print_disk_diagnosis():
import os
from os import path
from shutil import disk_usage
Expand Down
4 changes: 2 additions & 2 deletions flepimop/gempyor_pkg/tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def test_Timer_with_statement_success():
time.sleep(1)


def test_aws_disk_diagnosis_success():
utils.aws_disk_diagnosis()
def test_print_disk_diagnosis_success():
utils.print_disk_diagnosis()


def test_profile_success():
Expand Down

0 comments on commit be78600

Please sign in to comment.