Skip to content

Commit

Permalink
Merge pull request #420 from HopkinsIDD/hotfix/correct-poisson-statis…
Browse files Browse the repository at this point in the history
…tics-pmf

* Correct formula for poisson log-likelihood

Per this comment #375 (comment)
corrected the formula for the poisson log-likelihood, including swapping
the `ymodel` and `ydata` so it matches prior poisson likelihood. Added a
unit test case to demonstrate this change is to restore current
behavior.

* Add test case with poisson and zeros

Added a test case with poisson log-likelihood and zero valued data with
the `zero_to_one` flag set to `True`.

* Simplify poisson LL tests with only one variable
  • Loading branch information
TimothyWillard authored Dec 11, 2024
2 parents a104210 + 14d9bf3 commit b1d8404
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 4 deletions.
4 changes: 2 additions & 2 deletions flepimop/gempyor_pkg/src/gempyor/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray) -> xr.DataArray:
"""

dist_map = {
"pois": lambda ymodel, ydata: -(ymodel + 1)
+ ydata * np.log(ymodel + 1)
"pois": lambda ydata, ymodel: -ymodel
+ (ydata * np.log(ymodel))
- gammaln(ydata + 1),
# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
# OLD: # TODO: Swap out in favor of NEW
Expand Down
93 changes: 91 additions & 2 deletions flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,87 @@ def simple_valid_resample_and_scale_factory() -> MockStatisticInput:
)


def simple_valid_factory_with_pois() -> MockStatisticInput:
data_coords = {
"date": pd.date_range(date(2024, 1, 1), date(2024, 1, 10)),
"subpop": ["01", "02", "03"],
}
data_dim = [len(v) for v in data_coords.values()]
model_data = xr.Dataset(
data_vars={
"incidH": (
list(data_coords.keys()),
np.random.poisson(lam=20.0, size=data_dim),
),
},
coords=data_coords,
)
gt_data = xr.Dataset(
data_vars={
"incidH": (
list(data_coords.keys()),
np.random.poisson(lam=20.0, size=data_dim),
),
},
coords=data_coords,
)
return MockStatisticInput(
"total_hospitalizations",
{
"name": "sum_hospitalizations",
"sim_var": "incidH",
"data_var": "incidH",
"remove_na": True,
"add_one": True,
"likelihood": {"dist": "pois"},
},
model_data=model_data,
gt_data=gt_data,
)


def simple_valid_factory_with_pois_with_some_zeros() -> MockStatisticInput:
mock_input = simple_valid_factory_with_pois()

mock_input.config["zero_to_one"] = True

mock_input.model_data["incidH"].loc[
{
"date": mock_input.model_data.coords["date"][0],
"subpop": mock_input.model_data.coords["subpop"][0],
}
] = 0

mock_input.gt_data["incidH"].loc[
{
"date": mock_input.gt_data.coords["date"][2],
"subpop": mock_input.gt_data.coords["subpop"][2],
}
] = 0

mock_input.model_data["incidH"].loc[
{
"date": mock_input.model_data.coords["date"][1],
"subpop": mock_input.model_data.coords["subpop"][1],
}
] = 0
mock_input.gt_data["incidH"].loc[
{
"date": mock_input.gt_data.coords["date"][1],
"subpop": mock_input.gt_data.coords["subpop"][1],
}
] = 0

return mock_input


all_valid_factories = [
(simple_valid_factory),
(simple_valid_resample_factory),
(simple_valid_resample_factory),
(simple_valid_resample_and_scale_factory),
(simple_valid_factory_with_pois),
(simple_valid_factory_with_pois_with_some_zeros),
]


Expand Down Expand Up @@ -501,8 +577,21 @@ def test_llik(self, factory: Callable[[], MockStatisticInput]) -> None:
assert np.allclose(
log_likelihood.values,
scipy.stats.poisson.logpmf(
mock_inputs.gt_data[mock_inputs.config["data_var"]].values,
mock_inputs.model_data[mock_inputs.config["data_var"]].values,
np.where(
mock_inputs.config.get("zero_to_one", False)
& (mock_inputs.gt_data[mock_inputs.config["data_var"]].values == 0),
1,
mock_inputs.gt_data[mock_inputs.config["data_var"]].values,
),
np.where(
mock_inputs.config.get("zero_to_one", False)
& (
mock_inputs.model_data[mock_inputs.config["data_var"]].values
== 0
),
1,
mock_inputs.model_data[mock_inputs.config["data_var"]].values,
),
),
)
elif dist_name in {"norm", "norm_cov"}:
Expand Down

0 comments on commit b1d8404

Please sign in to comment.