Skip to content

Commit

Permalink
Correct formula for poisson log-likelihood
Browse files Browse the repository at this point in the history
Per this comment #375 (comment)
corrected the formula for the poisson log-likelihood, including swapping
the `ymodel` and `ydata` so it matches prior poisson likelihood. Added a
unit test case to demonstrate this change is to restore current
behavior.
  • Loading branch information
TimothyWillard committed Dec 9, 2024
1 parent b32cad8 commit 8e5baa3
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
4 changes: 2 additions & 2 deletions flepimop/gempyor_pkg/src/gempyor/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray) -> xr.DataArray:
"""

dist_map = {
"pois": lambda ymodel, ydata: -(ymodel + 1)
+ ydata * np.log(ymodel + 1)
"pois": lambda ydata, ymodel: -ymodel
+ (ydata * np.log(ymodel))
- gammaln(ydata + 1),
# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
# OLD: # TODO: Swap out in favor of NEW
Expand Down
48 changes: 48 additions & 0 deletions flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,59 @@ def simple_valid_resample_and_scale_factory() -> MockStatisticInput:
)


def simple_valid_factory_with_pois() -> MockStatisticInput:
data_coords = {
"date": pd.date_range(date(2024, 1, 1), date(2024, 1, 10)),
"subpop": ["01", "02", "03"],
}
data_dim = [len(v) for v in data_coords.values()]
model_data = xr.Dataset(
data_vars={
"incidH": (
list(data_coords.keys()),
np.random.poisson(lam=20.0, size=data_dim),
),
"incidD": (
list(data_coords.keys()),
np.random.poisson(lam=20.0, size=data_dim),
),
},
coords=data_coords,
)
gt_data = xr.Dataset(
data_vars={
"incidH": (
list(data_coords.keys()),
np.random.poisson(lam=20.0, size=data_dim),
),
"incidD": (
list(data_coords.keys()),
np.random.poisson(lam=20.0, size=data_dim),
),
},
coords=data_coords,
)
return MockStatisticInput(
"total_hospitalizations",
{
"name": "sum_hospitalizations",
"sim_var": "incidH",
"data_var": "incidH",
"remove_na": True,
"add_one": True,
"likelihood": {"dist": "pois"},
},
model_data=model_data,
gt_data=gt_data,
)


all_valid_factories = [
(simple_valid_factory),
(simple_valid_resample_factory),
(simple_valid_resample_factory),
(simple_valid_resample_and_scale_factory),
(simple_valid_factory_with_pois),
]


Expand Down

0 comments on commit 8e5baa3

Please sign in to comment.