From c2f423e737492ae1781653abcd7b828cd1851d71 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Mon, 19 Aug 2024 12:20:07 -0400 Subject: [PATCH 01/18] Draft documentation for `gempyor.statistics` * Wrote draft documentation for the `statistics` module in Google style guide. --- .../gempyor_pkg/src/gempyor/statistics.py | 157 +++++++++++++++--- 1 file changed, 135 insertions(+), 22 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index ea2cc72a3..7591a91fc 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -1,29 +1,58 @@ -import xarray as xr -import pandas as pd -import numpy as np +""" +Abstractions for interacting with output statistic configurations. + +This module provides the `Statistic` class which represents a entry in the inference -> +statistics section. +""" + +__all__ = ["Statistic"] + + import confuse +import numpy as np import scipy.stats +import xarray as xr class Statistic: """ - A statistic is a function that takes two time series and returns a scalar value. - It applies resample, scale, and regularization to the data before computing the statistic's log-loss. - Configuration: - - sim_var: the variable in the simulation data - - data_var: the variable in the ground truth data - - resample: resample the data before computing the statistic - - freq: the frequency to resample the data to - - aggregator: the aggregation function to use - - skipna: whether to skip NA values - - regularize: apply a regularization term to the data before computing the statistic - - # SkipNA is False by default, which results in NA values broadcasting when resampling (e.g a NA withing a sum makes the whole sum a NA) - # if True, then NA are replaced with 0 (for sum), 1 for product, ... 
- # In doubt, plot stat.plot_transformed() to see the effect of the resampling + Encapsulates logic for representing/implementing output statistic configurations. + + A statistic is a function that takes two time series and returns a scalar value. It + applies resample, scale, and regularization to the data before computing the + statistic's log-loss. + + Attributes: + data_var: The variable in the ground truth data. + dist: The name of the distribution to use for calculating log-likelihood. + name: The human readable name for the statistic given during instantiation. + params: Distribution parameters used in the log-likelihood calculation and + dependent on `dist`. + regularizations: Regularization functions that are added to the log loss of this + statistic. + resample: If the data should be resampled before computing the statistic. + resample_aggregator_name: The name of the aggregation function to use. + resample_freq: The frequency to resample the data to if the `resample` attribute + is `True`. + resample_skipna: If NAs should be skipped when aggregating. `False` by default. + scale: If the data should be rescaled before computing the statistic. + scale_func: The function to use when rescaling the data. Can be any function + exported by `numpy`. + sim_var: The variable in the simulation data. + zero_to_one: Should non-zero values be coerced to 1 when calculating + log-likelihood. """ - - def __init__(self, name, statistic_config: confuse.ConfigView): + + def __init__(self, name: str, statistic_config: confuse.ConfigView) -> None: + """ + Create an `Statistic` instance from a confuse config view. + + Args: + name: A human readable name for the statistic, mostly used for error + messages. + statistic_config: A confuse configuration view object describing an output + statistic. 
+ """ self.sim_var = statistic_config["sim_var"].as_str() self.data_var = statistic_config["data_var"].as_str() self.name = name @@ -62,12 +91,28 @@ def __init__(self, name, statistic_config: confuse.ConfigView): self.params = {} self.zero_to_one = False - # TODO: this should be set_zeros_to and only do it for the probabilily + # TODO: this should be set_zeros_to and only do it for the probability if statistic_config["zero_to_one"].exists(): self.zero_to_one = statistic_config["zero_to_one"].get() def _forecast_regularize(self, model_data, gt_data, **kwargs): - # scale the data so that the lastest X items are more important + """ + Regularization function to add weight to more recent forecasts. + + Args: + model_data: An xarray Dataset of the model data with date and subpop + dimensions. + gt_data: An xarray Dataset of the ground truth data with date and subpop + dimensions. + **kwargs: Optional keyword arguments that influence regularization. + Currently uses `last_n` for the number of observations to up weight and + `mult` for the coefficient of the regularization value. + + Returns: + The log-likelihood of the `last_n` observation up weighted by a factor of + `mult`. + """ + # scale the data so that the latest X items are more important last_n = kwargs.get("last_n", 4) mult = kwargs.get("mult", 2) @@ -76,7 +121,20 @@ def _forecast_regularize(self, model_data, gt_data, **kwargs): return mult * last_n_llik.sum().sum().values def _allsubpop_regularize(self, model_data, gt_data, **kwargs): - """add a regularization term that is the sum of all subpopulations""" + """ + Regularization function to add the sum of all subpopulations. + + Args: + model_data: An xarray Dataset of the model data with date and subpop + dimensions. + gt_data: An xarray Dataset of the ground truth data with date and subpop + dimensions. + **kwargs: Optional keyword arguments that influence regularization. + Currently uses `mult` for the coefficient of the regularization value. 
+ + Returns: + The sum of the subpopulations multiplied by `mult`. + """ mult = kwargs.get("mult", 1) llik_total = self.llik(model_data.sum("subpop"), gt_data.sum("subpop")) return mult * llik_total.sum().sum().values @@ -88,6 +146,15 @@ def __repr__(self) -> str: return f"A Statistic(): {self.__str__()}" def apply_resample(self, data): + """ + Resample a data set to the given frequency using the specified aggregation. + + Args: + data: An xarray dataset with "date" and "subpop" dimensions. + + Returns: + A resample dataset with similar dimensions to `data`. + """ if self.resample: aggregator_method = getattr(data.resample(date=self.resample_freq), self.resample_aggregator_name) return aggregator_method(skipna=self.resample_skipna) @@ -95,16 +162,49 @@ def apply_resample(self, data): return data def apply_scale(self, data): + """ + Scale a data set using the specified scaling function. + + Args: + data: An xarray dataset with "date" and "subpop" dimensions. + + Returns: + An xarray dataset of the same shape and dimensions as `data` with the + `scale_func` attribute applied. + """ if self.scale: return self.scale_func(data) else: return data def apply_transforms(self, data): + """ + Convenient wrapper for resampling and scaling a data set. + + The resampling is applied *before* scaling which can affect the log-likelihood. + + Args: + data: An xarray dataset with "date" and "subpop" dimensions. + + Returns: + An scaled and resampled dataset with similar dimensions to `data`. + """ data_scaled_resampled = self.apply_scale(self.apply_resample(data)) return data_scaled_resampled def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray): + """ + Compute the log-likelihood of observing the ground truth given model output. + + Args: + model_data: An xarray Dataset of the model data with date and subpop + dimensions. + gt_data: An xarray Dataset of the ground truth data with date and subpop + dimensions. 
+ + Returns: + The log-likelihood of observing `gt_data` from the model `model_data`. + """ dist_map = { "pois": scipy.stats.poisson.logpmf, "norm": lambda x, loc, scale: scipy.stats.norm.logpdf( @@ -137,6 +237,19 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray): return likelihood def compute_logloss(self, model_data, gt_data): + """ + Compute the logistic loss of observing the ground truth given model output. + + Args: + model_data: An xarray Dataset of the model data with date and subpop + dimensions. + gt_data: An xarray Dataset of the ground truth data with date and subpop + dimensions. + + Returns: + The logistic loss of observing `gt_data` from the model `model_data` + decomposed into the log-likelihood and regularizations. + """ model_data = self.apply_transforms(model_data[self.sim_var]) gt_data = self.apply_transforms(gt_data[self.data_var]) From cc3a691e55eeba0ef78088a128853997933aecee Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Mon, 19 Aug 2024 16:52:34 -0400 Subject: [PATCH 02/18] Applied black formatter --- .../gempyor_pkg/src/gempyor/statistics.py | 84 +++++++++++-------- 1 file changed, 48 insertions(+), 36 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index 7591a91fc..2ed7d496a 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -17,11 +17,11 @@ class Statistic: """ Encapsulates logic for representing/implementing output statistic configurations. - - A statistic is a function that takes two time series and returns a scalar value. It - applies resample, scale, and regularization to the data before computing the + + A statistic is a function that takes two time series and returns a scalar value. It + applies resample, scale, and regularization to the data before computing the statistic's log-loss. 
- + Attributes: data_var: The variable in the ground truth data. dist: The name of the distribution to use for calculating log-likelihood. @@ -36,19 +36,19 @@ class Statistic: is `True`. resample_skipna: If NAs should be skipped when aggregating. `False` by default. scale: If the data should be rescaled before computing the statistic. - scale_func: The function to use when rescaling the data. Can be any function + scale_func: The function to use when rescaling the data. Can be any function exported by `numpy`. sim_var: The variable in the simulation data. - zero_to_one: Should non-zero values be coerced to 1 when calculating + zero_to_one: Should non-zero values be coerced to 1 when calculating log-likelihood. """ - + def __init__(self, name: str, statistic_config: confuse.ConfigView) -> None: """ Create an `Statistic` instance from a confuse config view. - + Args: - name: A human readable name for the statistic, mostly used for error + name: A human readable name for the statistic, mostly used for error messages. statistic_config: A confuse configuration view object describing an output statistic. @@ -77,7 +77,10 @@ def __init__(self, name: str, statistic_config: confuse.ConfigView) -> None: if resample_config["aggregator"].exists(): self.resample_aggregator_name = resample_config["aggregator"].get() self.resample_skipna = False # TODO - if resample_config["aggregator"].exists() and resample_config["skipna"].exists(): + if ( + resample_config["aggregator"].exists() + and resample_config["skipna"].exists() + ): self.resample_skipna = resample_config["skipna"].get() self.scale = False @@ -98,42 +101,45 @@ def __init__(self, name: str, statistic_config: confuse.ConfigView) -> None: def _forecast_regularize(self, model_data, gt_data, **kwargs): """ Regularization function to add weight to more recent forecasts. - + Args: model_data: An xarray Dataset of the model data with date and subpop dimensions. 
gt_data: An xarray Dataset of the ground truth data with date and subpop dimensions. - **kwargs: Optional keyword arguments that influence regularization. + **kwargs: Optional keyword arguments that influence regularization. Currently uses `last_n` for the number of observations to up weight and - `mult` for the coefficient of the regularization value. - + `mult` for the coefficient of the regularization value. + Returns: - The log-likelihood of the `last_n` observation up weighted by a factor of + The log-likelihood of the `last_n` observation up weighted by a factor of `mult`. """ # scale the data so that the latest X items are more important last_n = kwargs.get("last_n", 4) mult = kwargs.get("mult", 2) - last_n_llik = self.llik(model_data.isel(date=slice(-last_n, None)), gt_data.isel(date=slice(-last_n, None))) + last_n_llik = self.llik( + model_data.isel(date=slice(-last_n, None)), + gt_data.isel(date=slice(-last_n, None)), + ) return mult * last_n_llik.sum().sum().values def _allsubpop_regularize(self, model_data, gt_data, **kwargs): """ Regularization function to add the sum of all subpopulations. - + Args: model_data: An xarray Dataset of the model data with date and subpop dimensions. gt_data: An xarray Dataset of the ground truth data with date and subpop dimensions. - **kwargs: Optional keyword arguments that influence regularization. + **kwargs: Optional keyword arguments that influence regularization. Currently uses `mult` for the coefficient of the regularization value. - + Returns: - The sum of the subpopulations multiplied by `mult`. + The sum of the subpopulations multiplied by `mult`. """ mult = kwargs.get("mult", 1) llik_total = self.llik(model_data.sum("subpop"), gt_data.sum("subpop")) @@ -148,15 +154,17 @@ def __repr__(self) -> str: def apply_resample(self, data): """ Resample a data set to the given frequency using the specified aggregation. - + Args: data: An xarray dataset with "date" and "subpop" dimensions. 
- + Returns: A resample dataset with similar dimensions to `data`. """ if self.resample: - aggregator_method = getattr(data.resample(date=self.resample_freq), self.resample_aggregator_name) + aggregator_method = getattr( + data.resample(date=self.resample_freq), self.resample_aggregator_name + ) return aggregator_method(skipna=self.resample_skipna) else: return data @@ -164,12 +172,12 @@ def apply_resample(self, data): def apply_scale(self, data): """ Scale a data set using the specified scaling function. - + Args: data: An xarray dataset with "date" and "subpop" dimensions. - + Returns: - An xarray dataset of the same shape and dimensions as `data` with the + An xarray dataset of the same shape and dimensions as `data` with the `scale_func` attribute applied. """ if self.scale: @@ -180,12 +188,12 @@ def apply_scale(self, data): def apply_transforms(self, data): """ Convenient wrapper for resampling and scaling a data set. - + The resampling is applied *before* scaling which can affect the log-likelihood. - + Args: data: An xarray dataset with "date" and "subpop" dimensions. - + Returns: An scaled and resampled dataset with similar dimensions to `data`. """ @@ -195,13 +203,13 @@ def apply_transforms(self, data): def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray): """ Compute the log-likelihood of observing the ground truth given model output. - + Args: model_data: An xarray Dataset of the model data with date and subpop dimensions. gt_data: An xarray Dataset of the ground truth data with date and subpop dimensions. - + Returns: The log-likelihood of observing `gt_data` from the model `model_data`. 
""" @@ -213,7 +221,9 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray): "norm_cov": lambda x, loc, scale: scipy.stats.norm.logpdf( x, loc=loc, scale=scale * loc.where(loc > 5, 5) ), # TODO: check, that it's really the loc - "nbinom": lambda x, n, p: scipy.stats.nbinom.logpmf(x, n=self.params.get("n"), p=model_data), + "nbinom": lambda x, n, p: scipy.stats.nbinom.logpmf( + x, n=self.params.get("n"), p=model_data + ), "rmse": lambda x, y: -np.log(np.nansum(np.sqrt((x - y) ** 2))), "absolute_error": lambda x, y: -np.log(np.nansum(np.abs(x - y))), } @@ -239,15 +249,15 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray): def compute_logloss(self, model_data, gt_data): """ Compute the logistic loss of observing the ground truth given model output. - + Args: model_data: An xarray Dataset of the model data with date and subpop dimensions. gt_data: An xarray Dataset of the ground truth data with date and subpop dimensions. - + Returns: - The logistic loss of observing `gt_data` from the model `model_data` + The logistic loss of observing `gt_data` from the model `model_data` decomposed into the log-likelihood and regularizations. """ model_data = self.apply_transforms(model_data[self.sim_var]) @@ -260,6 +270,8 @@ def compute_logloss(self, model_data, gt_data): regularization = 0 for reg_func, reg_config in self.regularizations: - regularization += reg_func(model_data=model_data, gt_data=gt_data, **reg_config) # Pass config parameters + regularization += reg_func( + model_data=model_data, gt_data=gt_data, **reg_config + ) # Pass config parameters return self.llik(model_data, gt_data).sum("date"), regularization From 867924a29a1c6701b8b9cbdcb127a6a1b924351d Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Tue, 20 Aug 2024 10:24:05 -0400 Subject: [PATCH 03/18] Type annotations, black formatter * Added missing type annotations and corrected already existing ones. 
* Applied black formatter to the file, including manually correcting some line-length issues. * Rearranged dunder methods. --- .../gempyor_pkg/src/gempyor/statistics.py | 47 +++++++++++++------ 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index 2ed7d496a..1f52516e4 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -98,7 +98,21 @@ def __init__(self, name: str, statistic_config: confuse.ConfigView) -> None: if statistic_config["zero_to_one"].exists(): self.zero_to_one = statistic_config["zero_to_one"].get() - def _forecast_regularize(self, model_data, gt_data, **kwargs): + def __str__(self) -> str: + return ( + f"{self.name}: {self.dist} between {self.sim_var} " + f"(sim) and {self.data_var} (data)." + ) + + def __repr__(self) -> str: + return f"A Statistic(): {self.__str__()}" + + def _forecast_regularize( + self, + model_data: xr.DataArray, + gt_data: xr.DataArray, + **kwargs: dict[str, int | float], + ) -> float: """ Regularization function to add weight to more recent forecasts. @@ -126,7 +140,12 @@ def _forecast_regularize(self, model_data, gt_data, **kwargs): return mult * last_n_llik.sum().sum().values - def _allsubpop_regularize(self, model_data, gt_data, **kwargs): + def _allsubpop_regularize( + self, + model_data: xr.DataArray, + gt_data: xr.DataArray, + **kwargs: dict[str, int | float], + ) -> float: """ Regularization function to add the sum of all subpopulations. @@ -145,13 +164,7 @@ def _allsubpop_regularize(self, model_data, gt_data, **kwargs): llik_total = self.llik(model_data.sum("subpop"), gt_data.sum("subpop")) return mult * llik_total.sum().sum().values - def __str__(self) -> str: - return f"{self.name}: {self.dist} between {self.sim_var} (sim) and {self.data_var} (data)." 
- - def __repr__(self) -> str: - return f"A Statistic(): {self.__str__()}" - - def apply_resample(self, data): + def apply_resample(self, data: xr.DataArray) -> xr.DataArray: """ Resample a data set to the given frequency using the specified aggregation. @@ -169,7 +182,7 @@ def apply_resample(self, data): else: return data - def apply_scale(self, data): + def apply_scale(self, data: xr.DataArray) -> xr.DataArray: """ Scale a data set using the specified scaling function. @@ -185,7 +198,7 @@ def apply_scale(self, data): else: return data - def apply_transforms(self, data): + def apply_transforms(self, data: xr.DataArray): """ Convenient wrapper for resampling and scaling a data set. @@ -200,7 +213,7 @@ def apply_transforms(self, data): data_scaled_resampled = self.apply_scale(self.apply_resample(data)) return data_scaled_resampled - def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray): + def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray) -> float: """ Compute the log-likelihood of observing the ground truth given model output. @@ -246,7 +259,9 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray): # TODO: check the order of the arguments return likelihood - def compute_logloss(self, model_data, gt_data): + def compute_logloss( + self, model_data: xr.DataArray, gt_data: xr.DataArray + ) -> tuple[float, float]: """ Compute the logistic loss of observing the ground truth given model output. 
@@ -265,7 +280,11 @@ def compute_logloss(self, model_data, gt_data): if not model_data.shape == gt_data.shape: raise ValueError( - f"{self.name} Statistic error: data and groundtruth do not have the same shape: model_data.shape={model_data.shape} != gt_data.shape={gt_data.shape}" + ( + f"{self.name} Statistic error: data and groundtruth do not have " + f"the same shape: model_data.shape={model_data.shape} != " + f"gt_data.shape={gt_data.shape}" + ) ) regularization = 0 From 8fb81cdce36118124ed3b8d5d5b2d21ab66067e5 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Tue, 20 Aug 2024 13:22:34 -0400 Subject: [PATCH 04/18] Initial unit test infra for `Statistic` class * Created initial unit testing infrastructure for the `Statistic` class from `gempyor.statistics`, starting with invalid regularization name value error. * Added default to `getattr` call to make unsupported regularization value error reachable. Should obsolete with better documentation. 
--- .../gempyor_pkg/src/gempyor/statistics.py | 2 +- .../tests/statistics/test_statistic_class.py | 72 +++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index 1f52516e4..53e2241ca 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -61,7 +61,7 @@ def __init__(self, name: str, statistic_config: confuse.ConfigView) -> None: if statistic_config["regularize"].exists(): for reg_config in statistic_config["regularize"]: # Iterate over the list reg_name = reg_config["name"].get() - reg_func = getattr(self, f"_{reg_name}_regularize") + reg_func = getattr(self, f"_{reg_name}_regularize", None) if reg_func is None: raise ValueError(f"Unsupported regularization: {reg_name}") self.regularizations.append((reg_func, reg_config.get())) diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py new file mode 100644 index 000000000..9347cd936 --- /dev/null +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -0,0 +1,72 @@ +import pathlib +from typing import Any, Callable + +import confuse +import pytest +import xarray as xr + +from gempyor.statistics import Statistic +from gempyor.testing import create_confuse_configview_from_dict + + +class MockStatisticInput: + def __init__( + self, + name: str, + config: dict[str, Any], + model_data: xr.DataArray | None = None, + gt_data: xr.DataArray | None = None, + ) -> None: + self.name = name + self.config = config + self.model_data = model_data + self.gt_data = gt_data + self._confuse_subview = None + + def create_confuse_subview(self) -> confuse.Subview: + if self._confuse_subview is None: + self._confuse_subview = create_confuse_configview_from_dict( + self.config, 
name=self.name + ) + return self._confuse_subview + + def create_statistic_instance(self) -> Statistic: + return Statistic(self.name, self.create_confuse_subview()) + + +def invalid_regularization_factory(tmp_path: pathlib.Path) -> MockStatisticInput: + return MockStatisticInput( + "total_hospitalizations", + { + "name": "sum_hospitalizations", + "aggregator": "sum", + "period": "1 months", + "sim_var": "incidH", + "data_var": "incidH", + "remove_na": True, + "add_one": True, + "likelihood": {"dist": "pois"}, + "regularize": [{"name": "forecast"}, {"name": "invalid"}], + }, + ) + + +class TestStatistic: + @pytest.mark.parametrize("factory", [(invalid_regularization_factory)]) + def test_unsupported_regularizations_value_error( + self, + tmp_path: pathlib.Path, + factory: Callable[[pathlib.Path], MockStatisticInput], + ) -> None: + mock_inputs = factory(tmp_path) + unsupported_name = next( + reg_name + for reg_name in [ + reg["name"] for reg in mock_inputs.config.get("regularize", []) + ] + if reg_name not in ["forecast", "allsubpop"] + ) + with pytest.raises( + ValueError, match=rf"^Unsupported regularization\: {unsupported_name}$" + ): + mock_inputs.create_statistic_instance() From b27ec5269160d88afded4e6ac0bc3917ba2d0680 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Tue, 20 Aug 2024 14:38:26 -0400 Subject: [PATCH 05/18] Add `Statistic` attributes test fixture * Added a test fixture to test the attributes of the `Statistic` class. * Removed unnecessary `tmp_path` pytest fixture dependency. * Improved documentation on the `Statistic` class' attributes and added a raises section for the constructor. 
--- .../gempyor_pkg/src/gempyor/statistics.py | 19 ++++- .../tests/statistics/test_statistic_class.py | 82 +++++++++++++++++-- 2 files changed, 92 insertions(+), 9 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index 53e2241ca..461b5dd41 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -31,14 +31,20 @@ class Statistic: regularizations: Regularization functions that are added to the log loss of this statistic. resample: If the data should be resampled before computing the statistic. - resample_aggregator_name: The name of the aggregation function to use. + Defaults to `False`. + resample_aggregator_name: The name of the aggregation function to use. This + attribute is not set when a "resample" section is not defined in the + `statistic_config` arg. resample_freq: The frequency to resample the data to if the `resample` attribute - is `True`. + is `True`. This attribute is not set when a "resample" section is not + defined in the `statistic_config` arg. resample_skipna: If NAs should be skipped when aggregating. `False` by default. + This attribute is not set when a "resample" section is not defined in the + `statistic_config` arg. scale: If the data should be rescaled before computing the statistic. scale_func: The function to use when rescaling the data. Can be any function - exported by `numpy`. - sim_var: The variable in the simulation data. + exported by `numpy`. This attribute is not set when a "scale" value is not + defined in the `statistic_config` arg. zero_to_one: Should non-zero values be coerced to 1 when calculating log-likelihood. """ @@ -52,6 +58,11 @@ def __init__(self, name: str, statistic_config: confuse.ConfigView) -> None: messages. statistic_config: A confuse configuration view object describing an output statistic. 
+ + Raises: + ValueError: If an unsupported regularization name is provided via the + `statistic_config` arg. Currently only 'forecast' and 'allsubpop' are + supported. """ self.sim_var = statistic_config["sim_var"].as_str() self.data_var = statistic_config["data_var"].as_str() diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py index 9347cd936..94a129cf0 100644 --- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -34,7 +34,7 @@ def create_statistic_instance(self) -> Statistic: return Statistic(self.name, self.create_confuse_subview()) -def invalid_regularization_factory(tmp_path: pathlib.Path) -> MockStatisticInput: +def invalid_regularization_factory() -> MockStatisticInput: return MockStatisticInput( "total_hospitalizations", { @@ -51,14 +51,28 @@ def invalid_regularization_factory(tmp_path: pathlib.Path) -> MockStatisticInput ) +def simple_valid_factory() -> MockStatisticInput: + return MockStatisticInput( + "total_hospitalizations", + { + "name": "sum_hospitalizations", + "aggregator": "sum", + "period": "1 months", + "sim_var": "incidH", + "data_var": "incidH", + "remove_na": True, + "add_one": True, + "likelihood": {"dist": "pois"}, + }, + ) + + class TestStatistic: @pytest.mark.parametrize("factory", [(invalid_regularization_factory)]) def test_unsupported_regularizations_value_error( - self, - tmp_path: pathlib.Path, - factory: Callable[[pathlib.Path], MockStatisticInput], + self, factory: Callable[[], MockStatisticInput] ) -> None: - mock_inputs = factory(tmp_path) + mock_inputs = factory() unsupported_name = next( reg_name for reg_name in [ @@ -70,3 +84,61 @@ def test_unsupported_regularizations_value_error( ValueError, match=rf"^Unsupported regularization\: {unsupported_name}$" ): mock_inputs.create_statistic_instance() + + @pytest.mark.parametrize("factory", 
[(simple_valid_factory)]) + def test_statistic_instance_attributes( + self, factory: Callable[[], MockStatisticInput] + ) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # `data_var` attribute + assert statistic.data_var == mock_inputs.config["data_var"] + + # `dist` attribute + assert statistic.dist == mock_inputs.config["likelihood"]["dist"] + + # `name` attribute + assert statistic.name == mock_inputs.name + + # `params` attribute + assert statistic.params == mock_inputs.config["likelihood"].get("params", {}) + + # `regularizations` attribute + assert statistic.regularizations == [ + (r["name"], r) for r in mock_inputs.config.get("regularize", []) + ] + + # `resample` attribute + resample_config = mock_inputs.config.get("resample", {}) + assert statistic.resample == (resample_config != {}) + + if resample_config: + # `resample_aggregator_name` attribute + assert statistic.resample_aggregator_name == resample_config.get( + "aggregator", "" + ) + + # `resample_freq` attribute + assert statistic.resample_freq == resample_config.get("freq", "") + + # `resample_skipna` attribute + assert ( + statistic.resample_skipna == resample_config.get("skipna", False) + if resample_config.get("aggregator") is not None + else False + ) + + # `scale` attribute + assert statistic.scale == (mock_inputs.config.get("scale") is not None) + + # `scale_func` attribute + if scale_func := mock_inputs.config.get("scale") is not None: + assert statistic.scale_func == scale_func + + # `sim_var` attribute + assert statistic.sim_var == mock_inputs.config["sim_var"] + + # `zero_to_one` attribute + assert statistic.zero_to_one == mock_inputs.config.get("zero_to_one", False) From f1b559c59a98866d144fec93b233da651b92de0d Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Tue, 20 Aug 2024 14:53:44 -0400 Subject: [PATCH 06/18] Add fixture for `str` and `repr` of `Statistic` Added a test fixture 
for the result of calling `str` and `repr` on an instance of the `Statistic` class. --- .../tests/statistics/test_statistic_class.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py index 94a129cf0..544d89f0d 100644 --- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -142,3 +142,20 @@ def test_statistic_instance_attributes( # `zero_to_one` attribute assert statistic.zero_to_one == mock_inputs.config.get("zero_to_one", False) + + @pytest.mark.parametrize("factory", [(simple_valid_factory)]) + def test_statistic_str_and_repr( + self, factory: Callable[[], MockStatisticInput] + ) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # Tests + statistic_str = ( + f"{mock_inputs.name}: {mock_inputs.config['likelihood']['dist']} between " + f"{mock_inputs.config['sim_var']} (sim) and " + f"{mock_inputs.config['data_var']} (data)." + ) + assert str(statistic) == statistic_str + assert repr(statistic) == f"A Statistic(): {statistic_str}" From b8891be4f0e02539fe76559a61e00a0af214da28 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:48:59 -0400 Subject: [PATCH 07/18] Corrected `llik` return type hint The return of `Statistic.llik` is actually an xarray DataArray instead of a float, but summed along the date dimension. 
--- flepimop/gempyor_pkg/src/gempyor/statistics.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index 461b5dd41..1659e6558 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -58,7 +58,7 @@ def __init__(self, name: str, statistic_config: confuse.ConfigView) -> None: messages. statistic_config: A confuse configuration view object describing an output statistic. - + Raises: ValueError: If an unsupported regularization name is provided via the `statistic_config` arg. Currently only 'forecast' and 'allsubpop' are @@ -224,7 +224,7 @@ def apply_transforms(self, data: xr.DataArray): data_scaled_resampled = self.apply_scale(self.apply_resample(data)) return data_scaled_resampled - def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray) -> float: + def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray) -> xr.DataArray: """ Compute the log-likelihood of observing the ground truth given model output. @@ -235,7 +235,8 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray) -> float: dimensions. Returns: - The log-likelihood of observing `gt_data` from the model `model_data`. + The log-likelihood of observing `gt_data` from the model `model_data` as an + xarray DataArray with a "subpop" dimension. """ dist_map = { "pois": scipy.stats.poisson.logpmf, @@ -272,7 +273,7 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray) -> float: def compute_logloss( self, model_data: xr.DataArray, gt_data: xr.DataArray - ) -> tuple[float, float]: + ) -> tuple[xr.DataArray, float]: """ Compute the logistic loss of observing the ground truth given model output. @@ -284,7 +285,8 @@ def compute_logloss( Returns: The logistic loss of observing `gt_data` from the model `model_data` - decomposed into the log-likelihood and regularizations. 
+ decomposed into the log-likelihood along the "subpop" dimension and + regularizations. """ model_data = self.apply_transforms(model_data[self.sim_var]) gt_data = self.apply_transforms(gt_data[self.data_var]) From 239ced62e7a725f96919c9ace57e9b9a8c80ced7 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Wed, 21 Aug 2024 09:31:07 -0400 Subject: [PATCH 08/18] Move TODO comments to GH-300 In particular see https://github.com/HopkinsIDD/flepiMoP/issues/300#issuecomment-2302065476. --- flepimop/gempyor_pkg/src/gempyor/statistics.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index 1659e6558..aed46357e 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -87,7 +87,7 @@ def __init__(self, name: str, statistic_config: confuse.ConfigView) -> None: self.resample_aggregator = "" if resample_config["aggregator"].exists(): self.resample_aggregator_name = resample_config["aggregator"].get() - self.resample_skipna = False # TODO + self.resample_skipna = False if ( resample_config["aggregator"].exists() and resample_config["skipna"].exists() @@ -105,7 +105,6 @@ def __init__(self, name: str, statistic_config: confuse.ConfigView) -> None: self.params = {} self.zero_to_one = False - # TODO: this should be set_zeros_to and only do it for the probability if statistic_config["zero_to_one"].exists(): self.zero_to_one = statistic_config["zero_to_one"].get() @@ -245,7 +244,7 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray) -> xr.DataArray: ), # wrong: "norm_cov": lambda x, loc, scale: scipy.stats.norm.logpdf( x, loc=loc, scale=scale * loc.where(loc > 5, 5) - ), # TODO: check, that it's really the loc + ), "nbinom": lambda x, n, p: scipy.stats.nbinom.logpmf( x, n=self.params.get("n"), p=model_data ), @@ -268,7 +267,6 @@ def llik(self, 
model_data: xr.DataArray, gt_data: xr.DataArray) -> xr.DataArray: likelihood = xr.DataArray(likelihood, coords=gt_data.coords, dims=gt_data.dims) - # TODO: check the order of the arguments return likelihood def compute_logloss( @@ -285,7 +283,7 @@ def compute_logloss( Returns: The logistic loss of observing `gt_data` from the model `model_data` - decomposed into the log-likelihood along the "subpop" dimension and + decomposed into the log-likelihood along the "subpop" dimension and regularizations. """ model_data = self.apply_transforms(model_data[self.sim_var]) From f5a9cb9c81439356cc7bcf14e474aa26fde75bca Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Wed, 21 Aug 2024 10:46:30 -0400 Subject: [PATCH 09/18] Initial tests for Statistic regularization Added the initial unit tests for the regularization methods, `_forecast_regularize` and `_allsubpop_regularize`, of the `Statistic` class. The tests are general and do not make claims about correctness for now. 
--- .../tests/statistics/test_statistic_class.py | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py index 544d89f0d..ceb7ea87a 100644 --- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -1,7 +1,9 @@ -import pathlib +from datetime import date from typing import Any, Callable import confuse +import numpy as np +import pandas as pd import pytest import xarray as xr @@ -52,6 +54,22 @@ def invalid_regularization_factory() -> MockStatisticInput: def simple_valid_factory() -> MockStatisticInput: + model_data = xr.DataArray( + data=np.random.randn(10, 3), + dims=("date", "subpop"), + coords={ + "date": pd.date_range(date(2024, 1, 1), date(2024, 1, 10)), + "subpop": ["01", "02", "03"], + }, + ) + gt_data = xr.DataArray( + data=np.random.randn(10, 3), + dims=("date", "subpop"), + coords={ + "date": pd.date_range(date(2024, 1, 1), date(2024, 1, 10)), + "subpop": ["01", "02", "03"], + }, + ) return MockStatisticInput( "total_hospitalizations", { @@ -64,6 +82,8 @@ def simple_valid_factory() -> MockStatisticInput: "add_one": True, "likelihood": {"dist": "pois"}, }, + model_data=model_data, + gt_data=gt_data, ) @@ -159,3 +179,31 @@ def test_statistic_str_and_repr( ) assert str(statistic) == statistic_str assert repr(statistic) == f"A Statistic(): {statistic_str}" + + @pytest.mark.parametrize("factory,last_n,mult", [(simple_valid_factory, 4, 2.0)]) + def test_forecast_regularize( + self, factory: Callable[[], MockStatisticInput], last_n: int, mult: int | float + ) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # Tests + forecast_regularization = statistic._forecast_regularize( + mock_inputs.model_data, mock_inputs.gt_data, last_n=last_n, mult=mult + ) + assert 
isinstance(forecast_regularization, float) + + @pytest.mark.parametrize("factory,mult", [(simple_valid_factory, 2.0)]) + def test_allsubpop_regularize( + self, factory: Callable[[], MockStatisticInput], mult: int | float + ) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # Tests + forecast_regularization = statistic._allsubpop_regularize( + mock_inputs.model_data, mock_inputs.gt_data, mult=mult + ) + assert isinstance(forecast_regularization, float) From 878c5392267aefffc58cde32d36ca4824f7ade81 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Wed, 21 Aug 2024 11:46:44 -0400 Subject: [PATCH 10/18] Unit tests for `Statistic.apply_resample` Added unit tests for the `apply_resample` method, including creating a new factory, `simple_valid_resample_factory`, which hits the "resample config present" of the code path. --- .../tests/statistics/test_statistic_class.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py index ceb7ea87a..3865a4657 100644 --- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -87,6 +87,44 @@ def simple_valid_factory() -> MockStatisticInput: ) +def simple_valid_resample_factory() -> MockStatisticInput: + date_coords = pd.date_range(date(2024, 1, 1), date(2024, 12, 31)) + subpop_coords = ["01", "02", "03", "04"] + dim = (len(date_coords), len(subpop_coords)) + model_data = xr.DataArray( + data=np.random.randn(*dim), + dims=("date", "subpop"), + coords={ + "date": date_coords, + "subpop": subpop_coords, + }, + ) + gt_data = xr.DataArray( + data=np.random.randn(*dim), + dims=("date", "subpop"), + coords={ + "date": date_coords, + "subpop": subpop_coords, + }, + ) + return MockStatisticInput( + 
"total_hospitalizations", + { + "name": "sum_hospitalizations", + "aggregator": "sum", + "period": "1 months", + "sim_var": "incidH", + "data_var": "incidH", + "remove_na": True, + "add_one": True, + "likelihood": {"dist": "pois"}, + "resample": {"freq": "MS", "aggregator": "sum"}, + }, + model_data=model_data, + gt_data=gt_data, + ) + + class TestStatistic: @pytest.mark.parametrize("factory", [(invalid_regularization_factory)]) def test_unsupported_regularizations_value_error( @@ -207,3 +245,33 @@ def test_allsubpop_regularize( mock_inputs.model_data, mock_inputs.gt_data, mult=mult ) assert isinstance(forecast_regularization, float) + + @pytest.mark.parametrize( + "factory", [(simple_valid_factory), (simple_valid_resample_factory)] + ) + def test_apply_resample(self, factory: Callable[[], MockStatisticInput]) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # Tests + resampled_data = statistic.apply_resample(mock_inputs.model_data) + if resample_config := mock_inputs.config.get("resample", {}): + # Resample config + expected_resampled_data = mock_inputs.model_data.resample( + date=resample_config.get("freq", "") + ) + aggregation_func = getattr( + expected_resampled_data, resample_config.get("aggregator", "") + ) + expected_resampled_data = aggregation_func( + skipna=( + resample_config.get("skipna", False) + if resample_config.get("aggregator") is not None + else False + ) + ) + assert resampled_data.identical(expected_resampled_data) + else: + # No resample config, `apply_resample` returns our input + assert resampled_data.identical(mock_inputs.model_data) From 31368be36cd5aeec56f73b3668349d98b2e09d79 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Wed, 21 Aug 2024 13:35:18 -0400 Subject: [PATCH 11/18] Added unit tests for `Statistic.apply_scale` Added unit tests for `apply_scale` method including a new factory that produces an input set with a 'scale' 
config. Fixed a bug where the scale function was not applied even if provided. This is a *breaking* change, but doesn't affect currently existing test suite, need to see if this affects any currently existing config files. --- .../gempyor_pkg/src/gempyor/statistics.py | 1 + .../tests/statistics/test_statistic_class.py | 56 +++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index aed46357e..fd1dd1a43 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -96,6 +96,7 @@ def __init__(self, name: str, statistic_config: confuse.ConfigView) -> None: self.scale = False if statistic_config["scale"].exists(): + self.scale = True self.scale_func = getattr(np, statistic_config["scale"].get()) self.dist = statistic_config["likelihood"]["dist"].get() diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py index 3865a4657..8f51d1716 100644 --- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -125,6 +125,44 @@ def simple_valid_resample_factory() -> MockStatisticInput: ) +def simple_valid_scale_factory() -> MockStatisticInput: + date_coords = pd.date_range(date(2024, 1, 1), date(2024, 12, 31)) + subpop_coords = ["01", "02", "03", "04"] + dim = (len(date_coords), len(subpop_coords)) + model_data = xr.DataArray( + data=np.random.randn(*dim), + dims=("date", "subpop"), + coords={ + "date": date_coords, + "subpop": subpop_coords, + }, + ) + gt_data = xr.DataArray( + data=np.random.randn(*dim), + dims=("date", "subpop"), + coords={ + "date": date_coords, + "subpop": subpop_coords, + }, + ) + return MockStatisticInput( + "total_hospitalizations", + { + "name": "sum_hospitalizations", + "aggregator": "sum", + "period": "1 months", + "sim_var": 
"incidH", + "data_var": "incidH", + "remove_na": True, + "add_one": True, + "likelihood": {"dist": "pois"}, + "scale": "exp", + }, + model_data=model_data, + gt_data=gt_data, + ) + + class TestStatistic: @pytest.mark.parametrize("factory", [(invalid_regularization_factory)]) def test_unsupported_regularizations_value_error( @@ -275,3 +313,21 @@ def test_apply_resample(self, factory: Callable[[], MockStatisticInput]) -> None else: # No resample config, `apply_resample` returns our input assert resampled_data.identical(mock_inputs.model_data) + + @pytest.mark.parametrize( + "factory", [(simple_valid_factory), (simple_valid_scale_factory)] + ) + def test_apply_scale(self, factory: Callable[[], MockStatisticInput]) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # Tests + scaled_data = statistic.apply_scale(mock_inputs.model_data) + if (scale_func := mock_inputs.config.get("scale")) is not None: + # Scale config + expected_scaled_data = getattr(np, scale_func)(mock_inputs.model_data) + assert scaled_data.identical(expected_scaled_data) + else: + # No scale config, `apply_scale` is a no-op + assert scaled_data.identical(mock_inputs.model_data) From 84f36325b6639ce17d3e4223c7f7db34c53f275a Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Wed, 21 Aug 2024 15:42:20 -0400 Subject: [PATCH 12/18] Added unit test for `Statistic.apply_transforms` Added unit tests for the `apply_transforms` method of `Statistic`, including making a new factory that includes both resampling and scaling configuration. 
--- .../tests/statistics/test_statistic_class.py | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py index 8f51d1716..041206f6f 100644 --- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -163,6 +163,45 @@ def simple_valid_scale_factory() -> MockStatisticInput: ) +def simple_valid_resample_and_scale_factory() -> MockStatisticInput: + date_coords = pd.date_range(date(2024, 1, 1), date(2024, 12, 31)) + subpop_coords = ["01", "02", "03", "04"] + dim = (len(date_coords), len(subpop_coords)) + model_data = xr.DataArray( + data=np.random.randn(*dim), + dims=("date", "subpop"), + coords={ + "date": date_coords, + "subpop": subpop_coords, + }, + ) + gt_data = xr.DataArray( + data=np.random.randn(*dim), + dims=("date", "subpop"), + coords={ + "date": date_coords, + "subpop": subpop_coords, + }, + ) + return MockStatisticInput( + "total_hospitalizations", + { + "name": "sum_hospitalizations", + "aggregator": "sum", + "period": "1 months", + "sim_var": "incidH", + "data_var": "incidH", + "remove_na": True, + "add_one": True, + "likelihood": {"dist": "pois"}, + "resample": {"freq": "W", "aggregator": "max"}, + "scale": "sin", + }, + model_data=model_data, + gt_data=gt_data, + ) + + class TestStatistic: @pytest.mark.parametrize("factory", [(invalid_regularization_factory)]) def test_unsupported_regularizations_value_error( @@ -331,3 +370,42 @@ def test_apply_scale(self, factory: Callable[[], MockStatisticInput]) -> None: else: # No scale config, `apply_scale` is a no-op assert scaled_data.identical(mock_inputs.model_data) + + @pytest.mark.parametrize( + "factory", + [ + (simple_valid_factory), + (simple_valid_resample_factory), + (simple_valid_scale_factory), + (simple_valid_resample_and_scale_factory), + ], + ) + def test_apply_transforms(self, 
factory: Callable[[], MockStatisticInput]) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # Tests + transformed_data = statistic.apply_transforms(mock_inputs.model_data) + expected_transformed_data = mock_inputs.model_data.copy() + if resample_config := mock_inputs.config.get("resample", {}): + # Resample config + expected_transformed_data = expected_transformed_data.resample( + date=resample_config.get("freq", "") + ) + aggregation_func = getattr( + expected_transformed_data, resample_config.get("aggregator", "") + ) + expected_transformed_data = aggregation_func( + skipna=( + resample_config.get("skipna", False) + if resample_config.get("aggregator") is not None + else False + ) + ) + if (scale_func := mock_inputs.config.get("scale")) is not None: + # Scale config + expected_transformed_data = getattr(np, scale_func)( + expected_transformed_data + ) + assert transformed_data.identical(expected_transformed_data) From a518f36b13b0858a7aa205c6c13739eae4c24c15 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Wed, 21 Aug 2024 15:54:00 -0400 Subject: [PATCH 13/18] Consolidate valid factories into global var Created global `all_valid_factories` that can be passed directly to the `pytest.mark.parametrize` decorator to test methods of the `Statistic` class against many configurations. 
--- .../tests/statistics/test_statistic_class.py | 34 ++++++++----------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py index 041206f6f..e466af0a5 100644 --- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -202,6 +202,14 @@ def simple_valid_resample_and_scale_factory() -> MockStatisticInput: ) +all_valid_factories = [ + (simple_valid_factory), + (simple_valid_resample_factory), + (simple_valid_resample_factory), + (simple_valid_resample_and_scale_factory), +] + + class TestStatistic: @pytest.mark.parametrize("factory", [(invalid_regularization_factory)]) def test_unsupported_regularizations_value_error( @@ -220,7 +228,7 @@ def test_unsupported_regularizations_value_error( ): mock_inputs.create_statistic_instance() - @pytest.mark.parametrize("factory", [(simple_valid_factory)]) + @pytest.mark.parametrize("factory", all_valid_factories) def test_statistic_instance_attributes( self, factory: Callable[[], MockStatisticInput] ) -> None: @@ -269,8 +277,8 @@ def test_statistic_instance_attributes( assert statistic.scale == (mock_inputs.config.get("scale") is not None) # `scale_func` attribute - if scale_func := mock_inputs.config.get("scale") is not None: - assert statistic.scale_func == scale_func + if (scale_func := mock_inputs.config.get("scale")) is not None: + assert statistic.scale_func == getattr(np, scale_func) # `sim_var` attribute assert statistic.sim_var == mock_inputs.config["sim_var"] @@ -278,7 +286,7 @@ def test_statistic_instance_attributes( # `zero_to_one` attribute assert statistic.zero_to_one == mock_inputs.config.get("zero_to_one", False) - @pytest.mark.parametrize("factory", [(simple_valid_factory)]) + @pytest.mark.parametrize("factory", all_valid_factories) def test_statistic_str_and_repr( self, factory: Callable[[], 
MockStatisticInput] ) -> None: @@ -323,9 +331,7 @@ def test_allsubpop_regularize( ) assert isinstance(forecast_regularization, float) - @pytest.mark.parametrize( - "factory", [(simple_valid_factory), (simple_valid_resample_factory)] - ) + @pytest.mark.parametrize("factory", all_valid_factories) def test_apply_resample(self, factory: Callable[[], MockStatisticInput]) -> None: # Setup mock_inputs = factory() @@ -353,9 +359,7 @@ def test_apply_resample(self, factory: Callable[[], MockStatisticInput]) -> None # No resample config, `apply_resample` returns our input assert resampled_data.identical(mock_inputs.model_data) - @pytest.mark.parametrize( - "factory", [(simple_valid_factory), (simple_valid_scale_factory)] - ) + @pytest.mark.parametrize("factory", all_valid_factories) def test_apply_scale(self, factory: Callable[[], MockStatisticInput]) -> None: # Setup mock_inputs = factory() @@ -371,15 +375,7 @@ def test_apply_scale(self, factory: Callable[[], MockStatisticInput]) -> None: # No scale config, `apply_scale` is a no-op assert scaled_data.identical(mock_inputs.model_data) - @pytest.mark.parametrize( - "factory", - [ - (simple_valid_factory), - (simple_valid_resample_factory), - (simple_valid_scale_factory), - (simple_valid_resample_and_scale_factory), - ], - ) + @pytest.mark.parametrize("factory", all_valid_factories) def test_apply_transforms(self, factory: Callable[[], MockStatisticInput]) -> None: # Setup mock_inputs = factory() From f94ba0f12534110a3d577094123fb0386743d959 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Thu, 22 Aug 2024 11:39:35 -0400 Subject: [PATCH 14/18] Add unit tests for `Statistic.llik` Added unit tests for the `llik` method of the `Statistic` class. Had to change factories to use RMSE by default for likelihood distribution since the poisson distribution only has integer support. 
--- .../tests/statistics/test_statistic_class.py | 61 +++++++++++++++++-- 1 file changed, 56 insertions(+), 5 deletions(-) diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py index e466af0a5..489dde515 100644 --- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd import pytest +import scipy import xarray as xr from gempyor.statistics import Statistic @@ -47,7 +48,7 @@ def invalid_regularization_factory() -> MockStatisticInput: "data_var": "incidH", "remove_na": True, "add_one": True, - "likelihood": {"dist": "pois"}, + "likelihood": {"dist": "rmse"}, "regularize": [{"name": "forecast"}, {"name": "invalid"}], }, ) @@ -80,7 +81,7 @@ def simple_valid_factory() -> MockStatisticInput: "data_var": "incidH", "remove_na": True, "add_one": True, - "likelihood": {"dist": "pois"}, + "likelihood": {"dist": "norm", "params": {"scale": 2.0}}, }, model_data=model_data, gt_data=gt_data, @@ -117,7 +118,7 @@ def simple_valid_resample_factory() -> MockStatisticInput: "data_var": "incidH", "remove_na": True, "add_one": True, - "likelihood": {"dist": "pois"}, + "likelihood": {"dist": "rmse"}, "resample": {"freq": "MS", "aggregator": "sum"}, }, model_data=model_data, @@ -155,7 +156,7 @@ def simple_valid_scale_factory() -> MockStatisticInput: "data_var": "incidH", "remove_na": True, "add_one": True, - "likelihood": {"dist": "pois"}, + "likelihood": {"dist": "rmse"}, "scale": "exp", }, model_data=model_data, @@ -193,7 +194,7 @@ def simple_valid_resample_and_scale_factory() -> MockStatisticInput: "data_var": "incidH", "remove_na": True, "add_one": True, - "likelihood": {"dist": "pois"}, + "likelihood": {"dist": "rmse"}, "resample": {"freq": "W", "aggregator": "max"}, "scale": "sin", }, @@ -405,3 +406,53 @@ def test_apply_transforms(self, factory: Callable[[], 
MockStatisticInput]) -> No expected_transformed_data ) assert transformed_data.identical(expected_transformed_data) + + @pytest.mark.parametrize("factory", all_valid_factories) + def test_llik(self, factory: Callable[[], MockStatisticInput]) -> None: + # Setup + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + # Tests + log_likelihood = statistic.llik(mock_inputs.model_data, mock_inputs.gt_data) + + assert isinstance(log_likelihood, xr.DataArray) + assert log_likelihood.dims == mock_inputs.gt_data.dims + assert log_likelihood.coords.equals(mock_inputs.gt_data.coords) + dist_name = mock_inputs.config["likelihood"]["dist"] + if dist_name in {"absolute_error", "rmse"}: + # MAE produces a single repeated number + assert np.allclose( + log_likelihood.values, + -np.log( + np.nansum(np.abs(mock_inputs.model_data - mock_inputs.gt_data)) + ), + ) + elif dist_name == "pois": + assert np.allclose( + log_likelihood.values, + scipy.stats.poisson.logpmf( + mock_inputs.gt_data.values, mock_inputs.model_data.values + ), + ) + elif dist_name == {"norm", "norm_cov"}: + scale = mock_inputs.config["likelihood"]["params"]["scale"] + if dist_name == "norm_cov": + scale *= mock_inputs.model_data.where(mock_inputs.model_data > 5, 5) + assert np.allclose( + log_likelihood.values, + scipy.stats.norm.logpdf( + mock_inputs.gt_data.values, + mock_inputs.model_data.values, + scale=scale, + ), + ) + elif dist_name == "nbinom": + assert np.allclose( + log_likelihood.values, + scipy.stats.nbinom.logpmf( + mock_inputs.gt_data.values, + n=mock_inputs.config["likelihood"]["params"]["n"], + p=mock_inputs.model_data.values, + ), + ) From 52bad21617d9af4d51322adef12e8816ecf02067 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Fri, 23 Aug 2024 08:43:14 -0400 Subject: [PATCH 15/18] Change model/gt data to use xarray Dataset Was previously using `xarray.DataArray` for `model_data` and `gt_data` in unit testing the 
`Statistic` class since that is what many methods expect. It seems though the main entry to the class, `compute_logloss` takes an `xarray.DataSet` that the class splices into `xarray.DataArray`s. The unit tests now more accurately reflect this. --- .../tests/statistics/test_statistic_class.py | 214 +++++++++++------- 1 file changed, 135 insertions(+), 79 deletions(-) diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py index 489dde515..6b15014eb 100644 --- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -1,4 +1,5 @@ from datetime import date +from itertools import product from typing import Any, Callable import confuse @@ -55,21 +56,24 @@ def invalid_regularization_factory() -> MockStatisticInput: def simple_valid_factory() -> MockStatisticInput: - model_data = xr.DataArray( - data=np.random.randn(10, 3), - dims=("date", "subpop"), - coords={ - "date": pd.date_range(date(2024, 1, 1), date(2024, 1, 10)), - "subpop": ["01", "02", "03"], + data_coords = { + "date": pd.date_range(date(2024, 1, 1), date(2024, 1, 10)), + "subpop": ["01", "02", "03"], + } + data_dim = [len(v) for v in data_coords.values()] + model_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), }, + coords=data_coords, ) - gt_data = xr.DataArray( - data=np.random.randn(10, 3), - dims=("date", "subpop"), - coords={ - "date": pd.date_range(date(2024, 1, 1), date(2024, 1, 10)), - "subpop": ["01", "02", "03"], + gt_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), }, + coords=data_coords, ) return MockStatisticInput( "total_hospitalizations", @@ -89,24 +93,24 @@ def simple_valid_factory() -> 
MockStatisticInput: def simple_valid_resample_factory() -> MockStatisticInput: - date_coords = pd.date_range(date(2024, 1, 1), date(2024, 12, 31)) - subpop_coords = ["01", "02", "03", "04"] - dim = (len(date_coords), len(subpop_coords)) - model_data = xr.DataArray( - data=np.random.randn(*dim), - dims=("date", "subpop"), - coords={ - "date": date_coords, - "subpop": subpop_coords, + data_coords = { + "date": pd.date_range(date(2024, 1, 1), date(2024, 12, 31)), + "subpop": ["01", "02", "03", "04"], + } + data_dim = [len(v) for v in data_coords.values()] + model_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), }, + coords=data_coords, ) - gt_data = xr.DataArray( - data=np.random.randn(*dim), - dims=("date", "subpop"), - coords={ - "date": date_coords, - "subpop": subpop_coords, + gt_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), }, + coords=data_coords, ) return MockStatisticInput( "total_hospitalizations", @@ -127,24 +131,24 @@ def simple_valid_resample_factory() -> MockStatisticInput: def simple_valid_scale_factory() -> MockStatisticInput: - date_coords = pd.date_range(date(2024, 1, 1), date(2024, 12, 31)) - subpop_coords = ["01", "02", "03", "04"] - dim = (len(date_coords), len(subpop_coords)) - model_data = xr.DataArray( - data=np.random.randn(*dim), - dims=("date", "subpop"), - coords={ - "date": date_coords, - "subpop": subpop_coords, + data_coords = { + "date": pd.date_range(date(2024, 1, 1), date(2024, 12, 31)), + "subpop": ["01", "02", "03", "04"], + } + data_dim = [len(v) for v in data_coords.values()] + model_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), }, + coords=data_coords, ) - gt_data = 
xr.DataArray( - data=np.random.randn(*dim), - dims=("date", "subpop"), - coords={ - "date": date_coords, - "subpop": subpop_coords, + gt_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), }, + coords=data_coords, ) return MockStatisticInput( "total_hospitalizations", @@ -165,24 +169,24 @@ def simple_valid_scale_factory() -> MockStatisticInput: def simple_valid_resample_and_scale_factory() -> MockStatisticInput: - date_coords = pd.date_range(date(2024, 1, 1), date(2024, 12, 31)) - subpop_coords = ["01", "02", "03", "04"] - dim = (len(date_coords), len(subpop_coords)) - model_data = xr.DataArray( - data=np.random.randn(*dim), - dims=("date", "subpop"), - coords={ - "date": date_coords, - "subpop": subpop_coords, + data_coords = { + "date": pd.date_range(date(2024, 1, 1), date(2024, 12, 31)), + "subpop": ["01", "02", "03", "04"], + } + data_dim = [len(v) for v in data_coords.values()] + model_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), }, + coords=data_coords, ) - gt_data = xr.DataArray( - data=np.random.randn(*dim), - dims=("date", "subpop"), - coords={ - "date": date_coords, - "subpop": subpop_coords, + gt_data = xr.Dataset( + data_vars={ + "incidH": (list(data_coords.keys()), np.random.randn(*data_dim)), + "incidD": (list(data_coords.keys()), np.random.randn(*data_dim)), }, + coords=data_coords, ) return MockStatisticInput( "total_hospitalizations", @@ -314,7 +318,10 @@ def test_forecast_regularize( # Tests forecast_regularization = statistic._forecast_regularize( - mock_inputs.model_data, mock_inputs.gt_data, last_n=last_n, mult=mult + mock_inputs.model_data[mock_inputs.config["sim_var"]], + mock_inputs.gt_data[mock_inputs.config["data_var"]], + last_n=last_n, + mult=mult, ) assert isinstance(forecast_regularization, float) 
@@ -328,7 +335,9 @@ def test_allsubpop_regularize( # Tests forecast_regularization = statistic._allsubpop_regularize( - mock_inputs.model_data, mock_inputs.gt_data, mult=mult + mock_inputs.model_data[mock_inputs.config["sim_var"]], + mock_inputs.gt_data[mock_inputs.config["data_var"]], + mult=mult, ) assert isinstance(forecast_regularization, float) @@ -339,12 +348,14 @@ def test_apply_resample(self, factory: Callable[[], MockStatisticInput]) -> None statistic = mock_inputs.create_statistic_instance() # Tests - resampled_data = statistic.apply_resample(mock_inputs.model_data) + resampled_data = statistic.apply_resample( + mock_inputs.model_data[mock_inputs.config["sim_var"]] + ) if resample_config := mock_inputs.config.get("resample", {}): # Resample config - expected_resampled_data = mock_inputs.model_data.resample( - date=resample_config.get("freq", "") - ) + expected_resampled_data = mock_inputs.model_data[ + mock_inputs.config["sim_var"] + ].resample(date=resample_config.get("freq", "")) aggregation_func = getattr( expected_resampled_data, resample_config.get("aggregator", "") ) @@ -358,7 +369,9 @@ def test_apply_resample(self, factory: Callable[[], MockStatisticInput]) -> None assert resampled_data.identical(expected_resampled_data) else: # No resample config, `apply_resample` returns our input - assert resampled_data.identical(mock_inputs.model_data) + assert resampled_data.identical( + mock_inputs.model_data[mock_inputs.config["sim_var"]] + ) @pytest.mark.parametrize("factory", all_valid_factories) def test_apply_scale(self, factory: Callable[[], MockStatisticInput]) -> None: @@ -367,14 +380,20 @@ def test_apply_scale(self, factory: Callable[[], MockStatisticInput]) -> None: statistic = mock_inputs.create_statistic_instance() # Tests - scaled_data = statistic.apply_scale(mock_inputs.model_data) + scaled_data = statistic.apply_scale( + mock_inputs.model_data[mock_inputs.config["sim_var"]] + ) if (scale_func := mock_inputs.config.get("scale")) is not None: # 
Scale config - expected_scaled_data = getattr(np, scale_func)(mock_inputs.model_data) + expected_scaled_data = getattr(np, scale_func)( + mock_inputs.model_data[mock_inputs.config["sim_var"]] + ) assert scaled_data.identical(expected_scaled_data) else: # No scale config, `apply_scale` is a no-op - assert scaled_data.identical(mock_inputs.model_data) + assert scaled_data.identical( + mock_inputs.model_data[mock_inputs.config["sim_var"]] + ) @pytest.mark.parametrize("factory", all_valid_factories) def test_apply_transforms(self, factory: Callable[[], MockStatisticInput]) -> None: @@ -383,8 +402,12 @@ def test_apply_transforms(self, factory: Callable[[], MockStatisticInput]) -> No statistic = mock_inputs.create_statistic_instance() # Tests - transformed_data = statistic.apply_transforms(mock_inputs.model_data) - expected_transformed_data = mock_inputs.model_data.copy() + transformed_data = statistic.apply_transforms( + mock_inputs.model_data[mock_inputs.config["sim_var"]] + ) + expected_transformed_data = mock_inputs.model_data[ + mock_inputs.config["sim_var"] + ].copy() if resample_config := mock_inputs.config.get("resample", {}): # Resample config expected_transformed_data = expected_transformed_data.resample( @@ -414,36 +437,52 @@ def test_llik(self, factory: Callable[[], MockStatisticInput]) -> None: statistic = mock_inputs.create_statistic_instance() # Tests - log_likelihood = statistic.llik(mock_inputs.model_data, mock_inputs.gt_data) + log_likelihood = statistic.llik( + mock_inputs.model_data[mock_inputs.config["sim_var"]], + mock_inputs.gt_data[mock_inputs.config["data_var"]], + ) assert isinstance(log_likelihood, xr.DataArray) - assert log_likelihood.dims == mock_inputs.gt_data.dims - assert log_likelihood.coords.equals(mock_inputs.gt_data.coords) + assert ( + log_likelihood.dims + == mock_inputs.gt_data[mock_inputs.config["data_var"]].dims + ) + assert log_likelihood.coords.equals( + mock_inputs.gt_data[mock_inputs.config["data_var"]].coords + ) dist_name = 
mock_inputs.config["likelihood"]["dist"]
         if dist_name in {"absolute_error", "rmse"}:
             # MAE produces a single repeated number
             assert np.allclose(
                 log_likelihood.values,
                 -np.log(
-                    np.nansum(np.abs(mock_inputs.model_data - mock_inputs.gt_data))
+                    np.nansum(
+                        np.abs(
+                            mock_inputs.model_data[mock_inputs.config["sim_var"]]
+                            - mock_inputs.gt_data[mock_inputs.config["data_var"]]
+                        )
+                    )
                 ),
             )
         elif dist_name == "pois":
             assert np.allclose(
                 log_likelihood.values,
                 scipy.stats.poisson.logpmf(
-                    mock_inputs.gt_data.values, mock_inputs.model_data.values
+                    mock_inputs.gt_data[mock_inputs.config["data_var"]].values,
+                    mock_inputs.model_data[mock_inputs.config["sim_var"]].values,
                 ),
             )
         elif dist_name == {"norm", "norm_cov"}:
             scale = mock_inputs.config["likelihood"]["params"]["scale"]
             if dist_name == "norm_cov":
-                scale *= mock_inputs.model_data.where(mock_inputs.model_data > 5, 5)
+                scale *= mock_inputs.model_data[mock_inputs.config["sim_var"]].where(
+                    mock_inputs.model_data[mock_inputs.config["sim_var"]] > 5, 5
+                )
             assert np.allclose(
                 log_likelihood.values,
                 scipy.stats.norm.logpdf(
-                    mock_inputs.gt_data.values,
-                    mock_inputs.model_data.values,
+                    mock_inputs.gt_data[mock_inputs.config["data_var"]].values,
+                    mock_inputs.model_data[mock_inputs.config["sim_var"]].values,
                     scale=scale,
                 ),
             )
@@ -451,8 +490,25 @@ def test_llik(self, factory: Callable[[], MockStatisticInput]) -> None:
             assert np.allclose(
                 log_likelihood.values,
                 scipy.stats.nbinom.logpmf(
-                    mock_inputs.gt_data.values,
+                    mock_inputs.gt_data[mock_inputs.config["data_var"]].values,
                     n=mock_inputs.config["likelihood"]["params"]["n"],
-                    p=mock_inputs.model_data.values,
+                    p=mock_inputs.model_data[mock_inputs.config["sim_var"]].values,
                 ),
             )
+
+    @pytest.mark.parametrize("factory", all_valid_factories)
+    def test_compute_logloss(self, factory: Callable[[], MockStatisticInput]) -> None:
+        # Setup
+        mock_inputs = factory()
+        statistic = mock_inputs.create_statistic_instance()
+
+        # Tests
+        log_likelihood, regularization = 
statistic.compute_logloss( + mock_inputs.model_data, mock_inputs.gt_data + ) + + assert True + + # print(regularization) + + # assert isinstance(regularization, float) From 02c9dc4a3e71749c4ee8a14617a318dfcde1f771 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Fri, 23 Aug 2024 09:29:43 -0400 Subject: [PATCH 16/18] Initial unit tests on `Statistic.compute_logloss` * Created initial unit tests on the `compute_logloss` method of `Statistic`, checking for structure but not correctness. * Updated documentation for `compute_logloss` to reflect the possible `ValueError` and the correct input types expected. * Changed internal variable of that method to a float to get a consistent float return for the second tuple entry from `compute_logloss`. --- .../gempyor_pkg/src/gempyor/statistics.py | 7 +++-- .../tests/statistics/test_statistic_class.py | 26 ++++++++++++------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index fd1dd1a43..ea10e7aa1 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -271,7 +271,7 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray) -> xr.DataArray: return likelihood def compute_logloss( - self, model_data: xr.DataArray, gt_data: xr.DataArray + self, model_data: xr.Dataset, gt_data: xr.Dataset ) -> tuple[xr.DataArray, float]: """ Compute the logistic loss of observing the ground truth given model output. @@ -286,6 +286,9 @@ def compute_logloss( The logistic loss of observing `gt_data` from the model `model_data` decomposed into the log-likelihood along the "subpop" dimension and regularizations. + + Raises: + ValueError: If `model_data` and `gt_data` do not have the same shape. 
""" model_data = self.apply_transforms(model_data[self.sim_var]) gt_data = self.apply_transforms(gt_data[self.data_var]) @@ -299,7 +302,7 @@ def compute_logloss( ) ) - regularization = 0 + regularization = 0.0 for reg_func, reg_config in self.regularizations: regularization += reg_func( model_data=model_data, gt_data=gt_data, **reg_config diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py index 6b15014eb..fb92afe4b 100644 --- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -194,8 +194,8 @@ def simple_valid_resample_and_scale_factory() -> MockStatisticInput: "name": "sum_hospitalizations", "aggregator": "sum", "period": "1 months", - "sim_var": "incidH", - "data_var": "incidH", + "sim_var": "incidD", + "data_var": "incidD", "remove_na": True, "add_one": True, "likelihood": {"dist": "rmse"}, @@ -447,7 +447,7 @@ def test_llik(self, factory: Callable[[], MockStatisticInput]) -> None: log_likelihood.dims == mock_inputs.gt_data[mock_inputs.config["data_var"]].dims ) - assert log_likelihood.coords.equals( + assert log_likelihood.coords.identical( mock_inputs.gt_data[mock_inputs.config["data_var"]].coords ) dist_name = mock_inputs.config["likelihood"]["dist"] @@ -501,14 +501,22 @@ def test_compute_logloss(self, factory: Callable[[], MockStatisticInput]) -> Non # Setup mock_inputs = factory() statistic = mock_inputs.create_statistic_instance() - - # Tests log_likelihood, regularization = statistic.compute_logloss( mock_inputs.model_data, mock_inputs.gt_data ) + regularization_config = mock_inputs.config.get("regularize", []) - assert True - - # print(regularization) + # Assertions on log_likelihood + assert isinstance(log_likelihood, xr.DataArray) + assert log_likelihood.coords.identical( + xr.Coordinates(coords={"subpop": mock_inputs.gt_data.coords.get("subpop")}) + ) - # assert 
isinstance(regularization, float) + # Assertions on regularization + assert isinstance(regularization, float) + if regularization_config: + # Regularizations on logistic loss + assert regularization != 0.0 + else: + # No regularizations on logistic loss + assert regularization == 0.0 From 4ff66823f705d9d1b3308d46d79eb6cf81e585e1 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Fri, 23 Aug 2024 10:38:06 -0400 Subject: [PATCH 17/18] Test fixture for data misshape `ValueError` Added a test fixture that confirms the `ValueError` raised when model data and ground truth data do not have the same shapes in `Statistic.compute_logloss`. --- .../tests/statistics/test_statistic_class.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py index fb92afe4b..b7258bad2 100644 --- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -55,6 +55,36 @@ def invalid_regularization_factory() -> MockStatisticInput: ) +def invalid_misshaped_data_factory() -> MockStatisticInput: + model_data = xr.Dataset( + data_vars={"incidH": (["date", "subpop"], np.random.randn(10, 3))}, + coords={ + "date": pd.date_range(date(2024, 1, 1), date(2024, 1, 10)), + "subpop": ["01", "02", "03"], + }, + ) + gt_data = xr.Dataset( + data_vars={"incidH": (["date", "subpop"], np.random.randn(11, 2))}, + coords={ + "date": pd.date_range(date(2024, 1, 1), date(2024, 1, 11)), + "subpop": ["02", "03"], + }, + ) + return MockStatisticInput( + "total_hospitalizations", + { + "name": "sum_hospitalizations", + "sim_var": "incidH", + "data_var": "incidH", + "remove_na": True, + "add_one": True, + "likelihood": {"dist": "norm", "params": {"scale": 2.0}}, + }, + model_data=model_data, + gt_data=gt_data, + ) + + def simple_valid_factory() -> 
MockStatisticInput: data_coords = { "date": pd.date_range(date(2024, 1, 1), date(2024, 1, 10)), @@ -496,6 +526,25 @@ def test_llik(self, factory: Callable[[], MockStatisticInput]) -> None: ), ) + @pytest.mark.parametrize("factory", [(invalid_misshaped_data_factory)]) + def test_compute_logloss_data_misshape_value_error( + self, factory: Callable[[], MockStatisticInput] + ) -> None: + mock_inputs = factory() + statistic = mock_inputs.create_statistic_instance() + + model_rows, model_cols = mock_inputs.model_data[ + mock_inputs.config["sim_var"] + ].shape + gt_rows, gt_cols = mock_inputs.gt_data[mock_inputs.config["data_var"]].shape + expected_match = ( + rf"^{mock_inputs.name} Statistic error\: data and groundtruth do not have " + rf"the same shape\: model\_data\.shape\=\({model_rows}\, {model_cols}\) " + rf"\!\= gt\_data\.shape\=\({gt_rows}\, {gt_cols}\)$" + ) + with pytest.raises(ValueError, match=expected_match): + statistic.compute_logloss(mock_inputs.model_data, mock_inputs.gt_data) + @pytest.mark.parametrize("factory", all_valid_factories) def test_compute_logloss(self, factory: Callable[[], MockStatisticInput]) -> None: # Setup From 15864c760f1738a2625995bc936d11a29540cc19 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Fri, 23 Aug 2024 11:01:47 -0400 Subject: [PATCH 18/18] Remove unnecessary entries from mock configs There were entries in the mock configs, modeled on existing configs, that are not considered by the `Statistic` class at all. Removed for clarity. 
--- .../tests/statistics/test_statistic_class.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py index b7258bad2..e18861e9e 100644 --- a/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py +++ b/flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py @@ -43,8 +43,6 @@ def invalid_regularization_factory() -> MockStatisticInput: "total_hospitalizations", { "name": "sum_hospitalizations", - "aggregator": "sum", - "period": "1 months", "sim_var": "incidH", "data_var": "incidH", "remove_na": True, @@ -109,8 +107,6 @@ def simple_valid_factory() -> MockStatisticInput: "total_hospitalizations", { "name": "sum_hospitalizations", - "aggregator": "sum", - "period": "1 months", "sim_var": "incidH", "data_var": "incidH", "remove_na": True, @@ -146,8 +142,6 @@ def simple_valid_resample_factory() -> MockStatisticInput: "total_hospitalizations", { "name": "sum_hospitalizations", - "aggregator": "sum", - "period": "1 months", "sim_var": "incidH", "data_var": "incidH", "remove_na": True, @@ -184,8 +178,6 @@ def simple_valid_scale_factory() -> MockStatisticInput: "total_hospitalizations", { "name": "sum_hospitalizations", - "aggregator": "sum", - "period": "1 months", "sim_var": "incidH", "data_var": "incidH", "remove_na": True, @@ -222,8 +214,6 @@ def simple_valid_resample_and_scale_factory() -> MockStatisticInput: "total_hospitalizations", { "name": "sum_hospitalizations", - "aggregator": "sum", - "period": "1 months", "sim_var": "incidD", "data_var": "incidD", "remove_na": True,