diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index ea2cc72a3..ea10e7aa1 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -1,29 +1,69 @@ -import xarray as xr -import pandas as pd -import numpy as np +""" +Abstractions for interacting with output statistic configurations. + +This module provides the `Statistic` class which represents a entry in the inference -> +statistics section. +""" + +__all__ = ["Statistic"] + + import confuse +import numpy as np import scipy.stats +import xarray as xr class Statistic: """ - A statistic is a function that takes two time series and returns a scalar value. - It applies resample, scale, and regularization to the data before computing the statistic's log-loss. - Configuration: - - sim_var: the variable in the simulation data - - data_var: the variable in the ground truth data - - resample: resample the data before computing the statistic - - freq: the frequency to resample the data to - - aggregator: the aggregation function to use - - skipna: whether to skip NA values - - regularize: apply a regularization term to the data before computing the statistic - - # SkipNA is False by default, which results in NA values broadcasting when resampling (e.g a NA withing a sum makes the whole sum a NA) - # if True, then NA are replaced with 0 (for sum), 1 for product, ... - # In doubt, plot stat.plot_transformed() to see the effect of the resampling + Encapsulates logic for representing/implementing output statistic configurations. + + A statistic is a function that takes two time series and returns a scalar value. It + applies resample, scale, and regularization to the data before computing the + statistic's log-loss. + + Attributes: + data_var: The variable in the ground truth data. + dist: The name of the distribution to use for calculating log-likelihood. + name: The human readable name for the statistic given during instantiation. + params: Distribution parameters used in the log-likelihood calculation and + dependent on `dist`. + regularizations: Regularization functions that are added to the log loss of this + statistic. + resample: If the data should be resampled before computing the statistic. + Defaults to `False`. + resample_aggregator_name: The name of the aggregation function to use. This + attribute is not set when a "resample" section is not defined in the + `statistic_config` arg. + resample_freq: The frequency to resample the data to if the `resample` attribute + is `True`. This attribute is not set when a "resample" section is not + defined in the `statistic_config` arg. + resample_skipna: If NAs should be skipped when aggregating. `False` by default. + This attribute is not set when a "resample" section is not defined in the + `statistic_config` arg. + scale: If the data should be rescaled before computing the statistic. + scale_func: The function to use when rescaling the data. Can be any function + exported by `numpy`. This attribute is not set when a "scale" value is not + defined in the `statistic_config` arg. + zero_to_one: Should non-zero values be coerced to 1 when calculating + log-likelihood. """ - def __init__(self, name, statistic_config: confuse.ConfigView): + def __init__(self, name: str, statistic_config: confuse.ConfigView) -> None: + """ + Create an `Statistic` instance from a confuse config view. + + Args: + name: A human readable name for the statistic, mostly used for error + messages. + statistic_config: A confuse configuration view object describing an output + statistic. + + Raises: + ValueError: If an unsupported regularization name is provided via the + `statistic_config` arg. Currently only 'forecast' and 'allsubpop' are + supported. + """ self.sim_var = statistic_config["sim_var"].as_str() self.data_var = statistic_config["data_var"].as_str() self.name = name @@ -32,7 +72,7 @@ def __init__(self, name, statistic_config: confuse.ConfigView): if statistic_config["regularize"].exists(): for reg_config in statistic_config["regularize"]: # Iterate over the list reg_name = reg_config["name"].get() - reg_func = getattr(self, f"_{reg_name}_regularize") + reg_func = getattr(self, f"_{reg_name}_regularize", None) if reg_func is None: raise ValueError(f"Unsupported regularization: {reg_name}") self.regularizations.append((reg_func, reg_config.get())) @@ -47,12 +87,16 @@ def __init__(self, name, statistic_config: confuse.ConfigView): self.resample_aggregator = "" if resample_config["aggregator"].exists(): self.resample_aggregator_name = resample_config["aggregator"].get() - self.resample_skipna = False # TODO - if resample_config["aggregator"].exists() and resample_config["skipna"].exists(): + self.resample_skipna = False + if ( + resample_config["aggregator"].exists() + and resample_config["skipna"].exists() + ): self.resample_skipna = resample_config["skipna"].get() self.scale = False if statistic_config["scale"].exists(): + self.scale = True self.scale_func = getattr(np, statistic_config["scale"].get()) self.dist = statistic_config["likelihood"]["dist"].get() @@ -62,49 +106,138 @@ def __init__(self, name, statistic_config: confuse.ConfigView): self.params = {} self.zero_to_one = False - # TODO: this should be set_zeros_to and only do it for the probabilily if statistic_config["zero_to_one"].exists(): self.zero_to_one = statistic_config["zero_to_one"].get() - def _forecast_regularize(self, model_data, gt_data, **kwargs): - # scale the data so that the lastest X items are more important + def __str__(self) -> str: + return ( + f"{self.name}: {self.dist} between {self.sim_var} " + f"(sim) and {self.data_var} (data)." + ) + + def __repr__(self) -> str: + return f"A Statistic(): {self.__str__()}" + + def _forecast_regularize( + self, + model_data: xr.DataArray, + gt_data: xr.DataArray, + **kwargs: dict[str, int | float], + ) -> float: + """ + Regularization function to add weight to more recent forecasts. + + Args: + model_data: An xarray Dataset of the model data with date and subpop + dimensions. + gt_data: An xarray Dataset of the ground truth data with date and subpop + dimensions. + **kwargs: Optional keyword arguments that influence regularization. + Currently uses `last_n` for the number of observations to up weight and + `mult` for the coefficient of the regularization value. + + Returns: + The log-likelihood of the `last_n` observation up weighted by a factor of + `mult`. + """ + # scale the data so that the latest X items are more important last_n = kwargs.get("last_n", 4) mult = kwargs.get("mult", 2) - last_n_llik = self.llik(model_data.isel(date=slice(-last_n, None)), gt_data.isel(date=slice(-last_n, None))) + last_n_llik = self.llik( + model_data.isel(date=slice(-last_n, None)), + gt_data.isel(date=slice(-last_n, None)), + ) return mult * last_n_llik.sum().sum().values - def _allsubpop_regularize(self, model_data, gt_data, **kwargs): - """add a regularization term that is the sum of all subpopulations""" + def _allsubpop_regularize( + self, + model_data: xr.DataArray, + gt_data: xr.DataArray, + **kwargs: dict[str, int | float], + ) -> float: + """ + Regularization function to add the sum of all subpopulations. + + Args: + model_data: An xarray Dataset of the model data with date and subpop + dimensions. + gt_data: An xarray Dataset of the ground truth data with date and subpop + dimensions. + **kwargs: Optional keyword arguments that influence regularization. + Currently uses `mult` for the coefficient of the regularization value. + + Returns: + The sum of the subpopulations multiplied by `mult`. + """ mult = kwargs.get("mult", 1) llik_total = self.llik(model_data.sum("subpop"), gt_data.sum("subpop")) return mult * llik_total.sum().sum().values - def __str__(self) -> str: - return f"{self.name}: {self.dist} between {self.sim_var} (sim) and {self.data_var} (data)." + def apply_resample(self, data: xr.DataArray) -> xr.DataArray: + """ + Resample a data set to the given frequency using the specified aggregation. - def __repr__(self) -> str: - return f"A Statistic(): {self.__str__()}" + Args: + data: An xarray dataset with "date" and "subpop" dimensions. - def apply_resample(self, data): + Returns: + A resample dataset with similar dimensions to `data`. + """ if self.resample: - aggregator_method = getattr(data.resample(date=self.resample_freq), self.resample_aggregator_name) + aggregator_method = getattr( + data.resample(date=self.resample_freq), self.resample_aggregator_name + ) return aggregator_method(skipna=self.resample_skipna) else: return data - def apply_scale(self, data): + def apply_scale(self, data: xr.DataArray) -> xr.DataArray: + """ + Scale a data set using the specified scaling function. + + Args: + data: An xarray dataset with "date" and "subpop" dimensions. + + Returns: + An xarray dataset of the same shape and dimensions as `data` with the + `scale_func` attribute applied. + """ if self.scale: return self.scale_func(data) else: return data - def apply_transforms(self, data): + def apply_transforms(self, data: xr.DataArray): + """ + Convenient wrapper for resampling and scaling a data set. + + The resampling is applied *before* scaling which can affect the log-likelihood. + + Args: + data: An xarray dataset with "date" and "subpop" dimensions. + + Returns: + An scaled and resampled dataset with similar dimensions to `data`. + """ data_scaled_resampled = self.apply_scale(self.apply_resample(data)) return data_scaled_resampled - def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray): + def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray) -> xr.DataArray: + """ + Compute the log-likelihood of observing the ground truth given model output. + + Args: + model_data: An xarray Dataset of the model data with date and subpop + dimensions. + gt_data: An xarray Dataset of the ground truth data with date and subpop + dimensions. + + Returns: + The log-likelihood of observing `gt_data` from the model `model_data` as an + xarray DataArray with a "subpop" dimension. + """ dist_map = { "pois": scipy.stats.poisson.logpmf, "norm": lambda x, loc, scale: scipy.stats.norm.logpdf( @@ -112,8 +245,10 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray): ), # wrong: "norm_cov": lambda x, loc, scale: scipy.stats.norm.logpdf( x, loc=loc, scale=scale * loc.where(loc > 5, 5) - ), # TODO: check, that it's really the loc - "nbinom": lambda x, n, p: scipy.stats.nbinom.logpmf(x, n=self.params.get("n"), p=model_data), + ), + "nbinom": lambda x, n, p: scipy.stats.nbinom.logpmf( + x, n=self.params.get("n"), p=model_data + ), "rmse": lambda x, y: -np.log(np.nansum(np.sqrt((x - y) ** 2))), "absolute_error": lambda x, y: -np.log(np.nansum(np.abs(x - y))), } @@ -133,20 +268,44 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray): likelihood = xr.DataArray(likelihood, coords=gt_data.coords, dims=gt_data.dims) - # TODO: check the order of the arguments return likelihood - def compute_logloss(self, model_data, gt_data): + def compute_logloss( + self, model_data: xr.Dataset, gt_data: xr.Dataset + ) -> tuple[xr.DataArray, float]: + """ + Compute the logistic loss of observing the ground truth given model output. + + Args: + model_data: An xarray Dataset of the model data with date and subpop + dimensions. + gt_data: An xarray Dataset of the ground truth data with date and subpop + dimensions. + + Returns: + The logistic loss of observing `gt_data` from the model `model_data` + decomposed into the log-likelihood along the "subpop" dimension and + regularizations. + + Raises: + ValueError: If `model_data` and `gt_data` do not have the same shape. + """ model_data = self.apply_transforms(model_data[self.sim_var]) gt_data = self.apply_transforms(gt_data[self.data_var]) if not model_data.shape == gt_data.shape: raise ValueError( - f"{self.name} Statistic error: data and groundtruth do not have the same shape: model_data.shape={model_data.shape} != gt_data.shape={gt_data.shape}" + ( + f"{self.name} Statistic error: data and groundtruth do not have " + f"the same shape: model_data.shape={model_data.shape} != " + f"gt_data.shape={gt_data.shape}" + ) ) - regularization = 0 + regularization = 0.0 for reg_func, reg_config in self.regularizations: - regularization += reg_func(model_data=model_data, gt_data=gt_data, **reg_config) # Pass config parameters + regularization += reg_func( + model_data=model_data, gt_data=gt_data, **reg_config + ) # Pass config parameters return self.llik(model_data, gt_data).sum("date"), regularization diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py new file mode 100644 index 000000000..e18861e9e --- /dev/null +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -0,0 +1,561 @@ +from datetime import date +from itertools import product +from typing import Any, Callable + +import confuse +import numpy as np +import pandas as pd +import pytest +import scipy +import xarray as xr + +from gempyor.statistics import Statistic +from gempyor.testing import create_confuse_configview_from_dict + + +class MockStatisticInput: + def __init__( + self, + name: str, + config: dict[str, Any], + model_data: xr.DataArray | None = None, + gt_data: xr.DataArray | None = None, + ) -> None: + self.name = name + self.config = config + self.model_data = model_data + self.gt_data = gt_data + self._confuse_subview = None + + def create_confuse_subview(self) -> confuse.Subview: + if self._confuse_subview is None: + self._confuse_subview = create_confuse_configview_from_dict( + self.config, name=self.name + ) + return self._confuse_subview + + def create_statistic_instance(self) -> Statistic: + return Statistic(self.name, self.create_confuse_subview()) + + +def invalid_regularization_factory() -> MockStatisticInput: + return MockStatisticInput( + "total_hospitalizations", + { + "name": "sum_hospitalizations", + "sim_var": "incidH", + "data_var": "incidH", + "remove_na": True, + "add_one": True, + "likelihood": {"dist": "rmse"}, + "regularize": [{"name": "forecast"}, {"name": "invalid"}], + }, + ) + + +def invalid_misshaped_data_factory() -> MockStatisticInput: + model_data = xr.Dataset( + data_vars={"incidH": (["date", "subpop"], np.random.randn(10, 3))}, + coords={ + "date": pd.date_range(date(2024, 1, 1), date(2024, 1, 10)), + "subpop": ["01", "02", "03"], + }, + ) + gt_data = xr.Dataset( + data_vars={"incidH": (["date", "subpop"], np.random.randn(11, 2))}, + coords={ + "date": pd.date_range(date(2024, 1, 1), date(2024, 1, 11)), + "subpop": ["02", "03"], + }, + ) + return MockStatisticInput( + "total_hospitalizations", + { + "name": "sum_hospitalizations", + "sim_var": "incidH", + "data_var": "incidH", + "remove_na": True, + "add_one": True, + "likelihood": {"dist": "norm", "params": {"scale": 2.0}}, + }, + model_data=model_data, + gt_data=gt_data, + ) + + +def simple_valid_factory() -> MockStatisticInput: + data_coords = { + "date": pd.date_range(date(2024, 1, 1), date(2024, 1, 10)), + "subpop": ["01", "02", "03"], + } + data_dim = [len(v) for v in data_coords.values()] + model_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), + }, + coords=data_coords, + ) + gt_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), + }, + coords=data_coords, + ) + return MockStatisticInput( + "total_hospitalizations", + { + "name": "sum_hospitalizations", + "sim_var": "incidH", + "data_var": "incidH", + "remove_na": True, + "add_one": True, + "likelihood": {"dist": "norm", "params": {"scale": 2.0}}, + }, + model_data=model_data, + gt_data=gt_data, + ) + + +def simple_valid_resample_factory() -> MockStatisticInput: + data_coords = { + "date": pd.date_range(date(2024, 1, 1), date(2024, 12, 31)), + "subpop": ["01", "02", "03", "04"], + } + data_dim = [len(v) for v in data_coords.values()] + model_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), + }, + coords=data_coords, + ) + gt_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), + }, + coords=data_coords, + ) + return MockStatisticInput( + "total_hospitalizations", + { + "name": "sum_hospitalizations", + "sim_var": "incidH", + "data_var": "incidH", + "remove_na": True, + "add_one": True, + "likelihood": {"dist": "rmse"}, + "resample": {"freq": "MS", "aggregator": "sum"}, + }, + model_data=model_data, + gt_data=gt_data, + ) + + +def simple_valid_scale_factory() -> MockStatisticInput: + data_coords = { + "date": pd.date_range(date(2024, 1, 1), date(2024, 12, 31)), + "subpop": ["01", "02", "03", "04"], + } + data_dim = [len(v) for v in data_coords.values()] + model_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), + }, + coords=data_coords, + ) + gt_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), + }, + coords=data_coords, + ) + return MockStatisticInput( + "total_hospitalizations", + { + "name": "sum_hospitalizations", + "sim_var": "incidH", + "data_var": "incidH", + "remove_na": True, + "add_one": True, + "likelihood": {"dist": "rmse"}, + "scale": "exp", + }, + model_data=model_data, + gt_data=gt_data, + ) + + +def simple_valid_resample_and_scale_factory() -> MockStatisticInput: + data_coords = { + "date": pd.date_range(date(2024, 1, 1), date(2024, 12, 31)), + "subpop": ["01", "02", "03", "04"], + } + data_dim = [len(v) for v in data_coords.values()] + model_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), + }, + coords=data_coords, + ) + gt_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), + }, + coords=data_coords, + ) + return MockStatisticInput( + "total_hospitalizations", + { + "name": "sum_hospitalizations", + "sim_var": "incidD", + "data_var": "incidD", + "remove_na": True, + "add_one": True, + "likelihood": {"dist": "rmse"}, + "resample": {"freq": "W", "aggregator": "max"}, + "scale": "sin", + }, + model_data=model_data, + gt_data=gt_data, + ) + + +all_valid_factories = [ + (simple_valid_factory), + (simple_valid_resample_factory), + (simple_valid_resample_factory), + (simple_valid_resample_and_scale_factory), +] + + +class TestStatistic: + @pytest.mark.parametrize("factory", [(invalid_regularization_factory)]) + def test_unsupported_regularizations_value_error( + self, factory: Callable[[], MockStatisticInput] + ) -> None: + mock_inputs = factory() + unsupported_name = next( + reg_name + for reg_name in [ + reg["name"] for reg in mock_inputs.config.get("regularize", []) + ] + if reg_name not in ["forecast", "allsubpop"] + ) + with pytest.raises( + ValueError, match=rf"^Unsupported regularization\: {unsupported_name}$" + ): + mock_inputs.create_statistic_instance() + + @pytest.mark.parametrize("factory", all_valid_factories) + def test_statistic_instance_attributes( + self, factory: Callable[[], MockStatisticInput] + ) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # `data_var` attribute + assert statistic.data_var == mock_inputs.config["data_var"] + + # `dist` attribute + assert statistic.dist == mock_inputs.config["likelihood"]["dist"] + + # `name` attribute + assert statistic.name == mock_inputs.name + + # `params` attribute + assert statistic.params == mock_inputs.config["likelihood"].get("params", {}) + + # `regularizations` attribute + assert statistic.regularizations == [ + (r["name"], r) for r in mock_inputs.config.get("regularize", []) + ] + + # `resample` attribute + resample_config = mock_inputs.config.get("resample", {}) + assert statistic.resample == (resample_config != {}) + + if resample_config: + # `resample_aggregator_name` attribute + assert statistic.resample_aggregator_name == resample_config.get( + "aggregator", "" + ) + + # `resample_freq` attribute + assert statistic.resample_freq == resample_config.get("freq", "") + + # `resample_skipna` attribute + assert ( + statistic.resample_skipna == resample_config.get("skipna", False) + if resample_config.get("aggregator") is not None + else False + ) + + # `scale` attribute + assert statistic.scale == (mock_inputs.config.get("scale") is not None) + + # `scale_func` attribute + if (scale_func := mock_inputs.config.get("scale")) is not None: + assert statistic.scale_func == getattr(np, scale_func) + + # `sim_var` attribute + assert statistic.sim_var == mock_inputs.config["sim_var"] + + # `zero_to_one` attribute + assert statistic.zero_to_one == mock_inputs.config.get("zero_to_one", False) + + @pytest.mark.parametrize("factory", all_valid_factories) + def test_statistic_str_and_repr( + self, factory: Callable[[], MockStatisticInput] + ) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # Tests + statistic_str = ( + f"{mock_inputs.name}: {mock_inputs.config['likelihood']['dist']} between " + f"{mock_inputs.config['sim_var']} (sim) and " + f"{mock_inputs.config['data_var']} (data)." + ) + assert str(statistic) == statistic_str + assert repr(statistic) == f"A Statistic(): {statistic_str}" + + @pytest.mark.parametrize("factory,last_n,mult", [(simple_valid_factory, 4, 2.0)]) + def test_forecast_regularize( + self, factory: Callable[[], MockStatisticInput], last_n: int, mult: int | float + ) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # Tests + forecast_regularization = statistic._forecast_regularize( + mock_inputs.model_data[mock_inputs.config["sim_var"]], + mock_inputs.gt_data[mock_inputs.config["data_var"]], + last_n=last_n, + mult=mult, + ) + assert isinstance(forecast_regularization, float) + + @pytest.mark.parametrize("factory,mult", [(simple_valid_factory, 2.0)]) + def test_allsubpop_regularize( + self, factory: Callable[[], MockStatisticInput], mult: int | float + ) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # Tests + forecast_regularization = statistic._allsubpop_regularize( + mock_inputs.model_data[mock_inputs.config["sim_var"]], + mock_inputs.gt_data[mock_inputs.config["data_var"]], + mult=mult, + ) + assert isinstance(forecast_regularization, float) + + @pytest.mark.parametrize("factory", all_valid_factories) + def test_apply_resample(self, factory: Callable[[], MockStatisticInput]) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # Tests + resampled_data = statistic.apply_resample( + mock_inputs.model_data[mock_inputs.config["sim_var"]] + ) + if resample_config := mock_inputs.config.get("resample", {}): + # Resample config + expected_resampled_data = mock_inputs.model_data[ + mock_inputs.config["sim_var"] + ].resample(date=resample_config.get("freq", "")) + aggregation_func = getattr( + expected_resampled_data, resample_config.get("aggregator", "") + ) + expected_resampled_data = aggregation_func( + skipna=( + resample_config.get("skipna", False) + if resample_config.get("aggregator") is not None + else False + ) + ) + assert resampled_data.identical(expected_resampled_data) + else: + # No resample config, `apply_resample` returns our input + assert resampled_data.identical( + mock_inputs.model_data[mock_inputs.config["sim_var"]] + ) + + @pytest.mark.parametrize("factory", all_valid_factories) + def test_apply_scale(self, factory: Callable[[], MockStatisticInput]) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # Tests + scaled_data = statistic.apply_scale( + mock_inputs.model_data[mock_inputs.config["sim_var"]] + ) + if (scale_func := mock_inputs.config.get("scale")) is not None: + # Scale config + expected_scaled_data = getattr(np, scale_func)( + mock_inputs.model_data[mock_inputs.config["sim_var"]] + ) + assert scaled_data.identical(expected_scaled_data) + else: + # No scale config, `apply_scale` is a no-op + assert scaled_data.identical( + mock_inputs.model_data[mock_inputs.config["sim_var"]] + ) + + @pytest.mark.parametrize("factory", all_valid_factories) + def test_apply_transforms(self, factory: Callable[[], MockStatisticInput]) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # Tests + transformed_data = statistic.apply_transforms( + mock_inputs.model_data[mock_inputs.config["sim_var"]] + ) + expected_transformed_data = mock_inputs.model_data[ + mock_inputs.config["sim_var"] + ].copy() + if resample_config := mock_inputs.config.get("resample", {}): + # Resample config + expected_transformed_data = expected_transformed_data.resample( + date=resample_config.get("freq", "") + ) + aggregation_func = getattr( + expected_transformed_data, resample_config.get("aggregator", "") + ) + expected_transformed_data = aggregation_func( + skipna=( + resample_config.get("skipna", False) + if resample_config.get("aggregator") is not None + else False + ) + ) + if (scale_func := mock_inputs.config.get("scale")) is not None: + # Scale config + expected_transformed_data = getattr(np, scale_func)( + expected_transformed_data + ) + assert transformed_data.identical(expected_transformed_data) + + @pytest.mark.parametrize("factory", all_valid_factories) + def test_llik(self, factory: Callable[[], MockStatisticInput]) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # Tests + log_likelihood = statistic.llik( + mock_inputs.model_data[mock_inputs.config["sim_var"]], + mock_inputs.gt_data[mock_inputs.config["data_var"]], + ) + + assert isinstance(log_likelihood, xr.DataArray) + assert ( + log_likelihood.dims + == mock_inputs.gt_data[mock_inputs.config["data_var"]].dims + ) + assert log_likelihood.coords.identical( + mock_inputs.gt_data[mock_inputs.config["data_var"]].coords + ) + dist_name = mock_inputs.config["likelihood"]["dist"] + if dist_name in {"absolute_error", "rmse"}: + # MAE produces a single repeated number + assert np.allclose( + log_likelihood.values, + -np.log( + np.nansum( + np.abs( + mock_inputs.model_data[mock_inputs.config["sim_var"]] + - mock_inputs.gt_data[mock_inputs.config["data_var"]] + ) + ) + ), + ) + elif dist_name == "pois": + assert np.allclose( + log_likelihood.values, + scipy.stats.poisson.logpmf( + mock_inputs.gt_data[mock_inputs.config["data_var"]].values, + mock_inputs.model_data[mock_inputs.config["data_var"]].values, + ), + ) + elif dist_name == {"norm", "norm_cov"}: + scale = mock_inputs.config["likelihood"]["params"]["scale"] + if dist_name == "norm_cov": + scale *= mock_inputs.model_data[mock_inputs.config["sim_var"]].where( + mock_inputs.model_data[mock_inputs.config["sim_var"]] > 5, 5 + ) + assert np.allclose( + log_likelihood.values, + scipy.stats.norm.logpdf( + mock_inputs.gt_data[mock_inputs.config["data_var"]].values, + mock_inputs.model_data[mock_inputs.config["sim_var"]].values, + scale=scale, + ), + ) + elif dist_name == "nbinom": + assert np.allclose( + log_likelihood.values, + scipy.stats.nbinom.logpmf( + mock_inputs.gt_data[mock_inputs.config["data_var"]].values, + n=mock_inputs.config["likelihood"]["params"]["n"], + p=mock_inputs.model_data[mock_inputs.config["sim_var"]].values, + ), + ) + + @pytest.mark.parametrize("factory", [(invalid_misshaped_data_factory)]) + def test_compute_logloss_data_misshape_value_error( + self, factory: Callable[[], MockStatisticInput] + ) -> None: + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + model_rows, model_cols = mock_inputs.model_data[ + mock_inputs.config["sim_var"] + ].shape + gt_rows, gt_cols = mock_inputs.gt_data[mock_inputs.config["data_var"]].shape + expected_match = ( + rf"^{mock_inputs.name} Statistic error\: data and groundtruth do not have " + rf"the same shape\: model\_data\.shape\=\({model_rows}\, {model_cols}\) " + rf"\!\= gt\_data\.shape\=\({gt_rows}\, {gt_cols}\)$" + ) + with pytest.raises(ValueError, match=expected_match): + statistic.compute_logloss(mock_inputs.model_data, mock_inputs.gt_data) + + @pytest.mark.parametrize("factory", all_valid_factories) + def test_compute_logloss(self, factory: Callable[[], MockStatisticInput]) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + log_likelihood, regularization = statistic.compute_logloss( + mock_inputs.model_data, mock_inputs.gt_data + ) + regularization_config = mock_inputs.config.get("regularize", []) + + # Assertions on log_likelihood + assert isinstance(log_likelihood, xr.DataArray) + assert log_likelihood.coords.identical( + xr.Coordinates(coords={"subpop": mock_inputs.gt_data.coords.get("subpop")}) + ) + + # Assertions on regularization + assert isinstance(regularization, float) + if regularization_config: + # Regularizations on logistic loss + assert regularization != 0.0 + else: + # No regularizations on logistic loss + assert regularization == 0.0