diff --git a/examples/simple_usa_statelevel/simple_usa_statelevel.yml b/examples/simple_usa_statelevel/simple_usa_statelevel.yml index 1b536d7c3..e3eddea60 100644 --- a/examples/simple_usa_statelevel/simple_usa_statelevel.yml +++ b/examples/simple_usa_statelevel/simple_usa_statelevel.yml @@ -53,17 +53,36 @@ seir_modifiers: period_start_date: 2023-11-01 period_end_date: 2024-02-01 subpop: "all" - value: - distribution: uniform - low: 1 - high: 1.4 + value: + distribution: truncnorm + mean: 1.3 + sd: 0.4 + a: 0.2 + b: 3 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -1 + b: 1 Ro_sunny: method: SinglePeriodModifier parameter: Ro period_start_date: 2024-02-01 period_end_date: 2024-04-01 subpop: "all" - value: 1.3 + value: + distribution: truncnorm + mean: 1.3 + sd: 0.4 + a: 0.1 + b: 10 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -1 + b: 1 Ro_all: method: StackedModifier modifiers: ["Ro_season","Ro_sunny"] diff --git a/flepimop/gempyor_pkg/src/gempyor/inference.py b/flepimop/gempyor_pkg/src/gempyor/inference.py index 7f36dd087..eff6ba554 100644 --- a/flepimop/gempyor_pkg/src/gempyor/inference.py +++ b/flepimop/gempyor_pkg/src/gempyor/inference.py @@ -12,6 +12,29 @@ import pyarrow.parquet as pq import xarray as xr import numba as nb +import copy + + +def emcee_logprob(proposal, modinf, inferpar, loss, static_sim_arguments, save=False): + if not inferpar.check_in_bound(proposal=proposal): + print("OUT OF BOUND!!") + return -np.inf + + snpi_df_mod, hnpi_df_mod = inferpar.inject_proposal(proposal=proposal, snpi_df = static_sim_arguments["snpi_df_ref"], hnpi_df = static_sim_arguments["hnpi_df_ref"]) + + ss = copy.deepcopy(static_sim_arguments) + ss["snpi_df_in"] = snpi_df_mod + ss["hnpi_df_in"] = hnpi_df_mod + del ss["snpi_df_ref"] + del ss["hnpi_df_ref"] + + + outcomes_df = simulation_atomic(**ss, modinf=modinf, save=save) + + ll_total, logloss, regularizations = loss.compute_logloss(model_df=outcomes_df, modinf=modinf) + print(f"llik is {ll_total}") + + return ll_total # TODO: there is way to many of these functions, merge with the R interface.py implementation to avoid code duplication diff --git a/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py b/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py index de7ce83d8..aeb56b461 100644 --- a/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py +++ b/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py @@ -103,6 +103,14 @@ def print_summary(self): f" >> affected subpop: {self.subpops[p_idx]}" ) + def __str__(self) -> str: + from collections import Counter + this_str = f"InferenceParameters: with {self.get_dim()} parameters: \n" + for key, value in Counter(self.ptypes).items(): + this_str += f" {key}: {value} parameters\n" + + return this_str + def get_dim(self): return len(self.pnames) @@ -152,7 +160,7 @@ def hit_ubs(self, proposal) -> np.ndarray: """ return np.array((proposal > self.ubs)) - def inject_proposal(self, proposal, hnpi_df=None, snpi_df=None): + def inject_proposal(self, proposal, snpi_df=None, hnpi_df=None,): """ Injects the proposal into model inputs, at the right place. @@ -167,14 +175,16 @@ def inject_proposal(self, proposal, hnpi_df=None, snpi_df=None): snpi_df_mod = snpi_df.copy(deep=True) hnpi_df_mod = hnpi_df.copy(deep=True) + # Ideally this should lie in each submodules, e.g NPI.inject, parameter.inject + for p_idx in range(self.get_dim()): - if self.ptypes[p_idx] == "snpi": + if self.ptypes[p_idx] == "seir_modifiers": snpi_df_mod.loc[ (snpi_df_mod["modifier_name"] == self.pnames[p_idx]) & (snpi_df_mod["subpop"] == self.subpops[p_idx]), "value", ] = proposal[p_idx] - elif self.ptypes[p_idx] == "hnpi": + elif self.ptypes[p_idx] == "outcome_modifiers": hnpi_df_mod.loc[ (hnpi_df_mod["modifier_name"] == self.pnames[p_idx]) & (hnpi_df_mod["subpop"] == self.subpops[p_idx]), diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index 8c233dc13..adc28a9f3 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -65,7 +65,6 @@ def _forecast_regularize(self, model_data, gt_data, **kwargs): # scale the data so that the lastest X items are more important last_n = kwargs.get("last_n", 4) mult = kwargs.get("mult", 2) - print("forecast", last_n, mult) last_n_llik = self.llik(model_data.isel(date=slice(-last_n, None)), gt_data.isel(date=slice(-last_n, None)))