Skip to content

Commit

Permalink
ironing and bug squasing
Browse files Browse the repository at this point in the history
  • Loading branch information
jcblemai committed Apr 19, 2024
1 parent 513e9cc commit eaa0f93
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 9 deletions.
29 changes: 24 additions & 5 deletions examples/simple_usa_statelevel/simple_usa_statelevel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,36 @@ seir_modifiers:
period_start_date: 2023-11-01
period_end_date: 2024-02-01
subpop: "all"
value:
distribution: uniform
low: 1
high: 1.4
value:
distribution: truncnorm
mean: 1.3
sd: 0.4
a: 0.2
b: 3
perturbation:
distribution: truncnorm
mean: 0
sd: 0.025
a: -1
b: 1
Ro_sunny:
method: SinglePeriodModifier
parameter: Ro
period_start_date: 2024-02-01
period_end_date: 2024-04-01
subpop: "all"
value: 1.3
value:
distribution: truncnorm
mean: 1.3
sd: 0.4
a: 0.1
b: 10
perturbation:
distribution: truncnorm
mean: 0
sd: 0.025
a: -1
b: 1
Ro_all:
method: StackedModifier
modifiers: ["Ro_season","Ro_sunny"]
Expand Down
23 changes: 23 additions & 0 deletions flepimop/gempyor_pkg/src/gempyor/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,29 @@
import pyarrow.parquet as pq
import xarray as xr
import numba as nb
import copy


def emcee_logprob(proposal, modinf, inferpar, loss, static_sim_arguments, save=False):
if not inferpar.check_in_bound(proposal=proposal):
print("OUT OF BOUND!!")
return -np.inf

snpi_df_mod, hnpi_df_mod = inferpar.inject_proposal(proposal=proposal, snpi_df = static_sim_arguments["snpi_df_ref"], hnpi_df = static_sim_arguments["hnpi_df_ref"])

ss = copy.deepcopy(static_sim_arguments)
ss["snpi_df_in"] = snpi_df_mod
ss["hnpi_df_in"] = hnpi_df_mod
del ss["snpi_df_ref"]
del ss["hnpi_df_ref"]


outcomes_df = simulation_atomic(**ss, modinf=modinf, save=save)

ll_total, logloss, regularizations = loss.compute_logloss(model_df=outcomes_df, modinf=modinf)
print(f"llik is {ll_total}")

return ll_total


# TODO: there is way to many of these functions, merge with the R interface.py implementation to avoid code duplication
Expand Down
16 changes: 13 additions & 3 deletions flepimop/gempyor_pkg/src/gempyor/inference_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ def print_summary(self):
f" >> affected subpop: {self.subpops[p_idx]}"
)

def __str__(self) -> str:
from collections import Counter
this_str = f"InferenceParameters: with {self.get_dim()} parameters: \n"
for key, value in Counter(self.ptypes).items():
this_str += f" {key}: {value} parameters\n"

return this_str

def get_dim(self):
return len(self.pnames)

Expand Down Expand Up @@ -152,7 +160,7 @@ def hit_ubs(self, proposal) -> np.ndarray:
"""
return np.array((proposal > self.ubs))

def inject_proposal(self, proposal, hnpi_df=None, snpi_df=None):
def inject_proposal(self, proposal, snpi_df=None, hnpi_df=None,):
"""
Injects the proposal into model inputs, at the right place.
Expand All @@ -167,14 +175,16 @@ def inject_proposal(self, proposal, hnpi_df=None, snpi_df=None):
snpi_df_mod = snpi_df.copy(deep=True)
hnpi_df_mod = hnpi_df.copy(deep=True)

# Ideally this should lie in each submodules, e.g NPI.inject, parameter.inject

for p_idx in range(self.get_dim()):
if self.ptypes[p_idx] == "snpi":
if self.ptypes[p_idx] == "seir_modifiers":
snpi_df_mod.loc[
(snpi_df_mod["modifier_name"] == self.pnames[p_idx])
& (snpi_df_mod["subpop"] == self.subpops[p_idx]),
"value",
] = proposal[p_idx]
elif self.ptypes[p_idx] == "hnpi":
elif self.ptypes[p_idx] == "outcome_modifiers":
hnpi_df_mod.loc[
(hnpi_df_mod["modifier_name"] == self.pnames[p_idx])
& (hnpi_df_mod["subpop"] == self.subpops[p_idx]),
Expand Down
1 change: 0 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def _forecast_regularize(self, model_data, gt_data, **kwargs):
# scale the data so that the lastest X items are more important
last_n = kwargs.get("last_n", 4)
mult = kwargs.get("mult", 2)
print("forecast", last_n, mult)

last_n_llik = self.llik(model_data.isel(date=slice(-last_n, None)), gt_data.isel(date=slice(-last_n, None)))

Expand Down

0 comments on commit eaa0f93

Please sign in to comment.