diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index 7b62c608e..3d85c2693 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -244,8 +244,8 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray) -> xr.DataArray: """ dist_map = { - "pois": lambda ymodel, ydata: -(ymodel + 1) - + ydata * np.log(ymodel + 1) + "pois": lambda ydata, ymodel: -ymodel + + (ydata * np.log(ymodel)) - gammaln(ydata + 1), # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> # OLD: # TODO: Swap out in favor of NEW diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py index 781820165..008b5073d 100644 --- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -228,11 +228,87 @@ def simple_valid_resample_and_scale_factory() -> MockStatisticInput: ) +def simple_valid_factory_with_pois() -> MockStatisticInput: + data_coords = { + "date": pd.date_range(date(2024, 1, 1), date(2024, 1, 10)), + "subpop": ["01", "02", "03"], + } + data_dim = [len(v) for v in data_coords.values()] + model_data = xr.Dataset( + data_vars={ + "incidH": ( + list(data_coords.keys()), + np.random.poisson(lam=20.0, size=data_dim), + ), + }, + coords=data_coords, + ) + gt_data = xr.Dataset( + data_vars={ + "incidH": ( + list(data_coords.keys()), + np.random.poisson(lam=20.0, size=data_dim), + ), + }, + coords=data_coords, + ) + return MockStatisticInput( + "total_hospitalizations", + { + "name": "sum_hospitalizations", + "sim_var": "incidH", + "data_var": "incidH", + "remove_na": True, + "add_one": True, + "likelihood": {"dist": "pois"}, + }, + model_data=model_data, + gt_data=gt_data, + ) + + +def simple_valid_factory_with_pois_with_some_zeros() -> MockStatisticInput: + mock_input = simple_valid_factory_with_pois() + + mock_input.config["zero_to_one"] = True + + mock_input.model_data["incidH"].loc[ + { + "date": mock_input.model_data.coords["date"][0], + "subpop": mock_input.model_data.coords["subpop"][0], + } + ] = 0 + + mock_input.gt_data["incidH"].loc[ + { + "date": mock_input.gt_data.coords["date"][2], + "subpop": mock_input.gt_data.coords["subpop"][2], + } + ] = 0 + + mock_input.model_data["incidH"].loc[ + { + "date": mock_input.model_data.coords["date"][1], + "subpop": mock_input.model_data.coords["subpop"][1], + } + ] = 0 + mock_input.gt_data["incidH"].loc[ + { + "date": mock_input.gt_data.coords["date"][1], + "subpop": mock_input.gt_data.coords["subpop"][1], + } + ] = 0 + + return mock_input + + all_valid_factories = [ (simple_valid_factory), (simple_valid_resample_factory), (simple_valid_resample_factory), (simple_valid_resample_and_scale_factory), + (simple_valid_factory_with_pois), + (simple_valid_factory_with_pois_with_some_zeros), ] @@ -501,8 +577,21 @@ def test_llik(self, factory: Callable[[], MockStatisticInput]) -> None: assert np.allclose( log_likelihood.values, scipy.stats.poisson.logpmf( - mock_inputs.gt_data[mock_inputs.config["data_var"]].values, - mock_inputs.model_data[mock_inputs.config["data_var"]].values, + np.where( + mock_inputs.config.get("zero_to_one", False) + & (mock_inputs.gt_data[mock_inputs.config["data_var"]].values == 0), + 1, + mock_inputs.gt_data[mock_inputs.config["data_var"]].values, + ), + np.where( + mock_inputs.config.get("zero_to_one", False) + & ( + mock_inputs.model_data[mock_inputs.config["data_var"]].values + == 0 + ), + 1, + mock_inputs.model_data[mock_inputs.config["data_var"]].values, + ), ), ) elif dist_name in {"norm", "norm_cov"}: