Document/Unit Test gempyor.statistics #304

Merged
merged 24 commits on Oct 18, 2024
Changes from 20 commits
Commits
24 commits
c2f423e
Draft documentation for `gempyor.statistics`
TimothyWillard Aug 19, 2024
cc3a691
Applied black formatter
TimothyWillard Aug 19, 2024
867924a
Type annotations, black formatter
TimothyWillard Aug 20, 2024
cc8cd79
Merge unit-test-gempyor-parameters into GH-300/statistics-unit-tests-…
TimothyWillard Aug 20, 2024
8fb81cd
Initial unit test infra for `Statistic` class
TimothyWillard Aug 20, 2024
b27ec52
Add `Statistic` attributes test fixture
TimothyWillard Aug 20, 2024
f1b559c
Add fixture for `str` and `repr` of `Statistic`
TimothyWillard Aug 20, 2024
b8891be
Corrected `llik` return type hint
TimothyWillard Aug 20, 2024
239ced6
Move TODO comments to GH-300
TimothyWillard Aug 21, 2024
f5a9cb9
Initial tests for Statistic regularization
TimothyWillard Aug 21, 2024
878c539
Unit tests for `Statistic.apply_resample`
TimothyWillard Aug 21, 2024
31368be
Added unit tests for `Statistic.apply_scale`
TimothyWillard Aug 21, 2024
84f3632
Added unit test for `Statistic.apply_transforms`
TimothyWillard Aug 21, 2024
a518f36
Consolidate valid factories into global var
TimothyWillard Aug 21, 2024
f94ba0f
Add unit tests for `Statistic.llik`
TimothyWillard Aug 22, 2024
52bad21
Change model/gt data to use xarray Dataset
TimothyWillard Aug 23, 2024
02c9dc4
Initial unit tests on `Statistic.compute_logloss`
TimothyWillard Aug 23, 2024
4ff6682
Test fixture for data misshape `ValueError`
TimothyWillard Aug 23, 2024
15864c7
Remove unnecessary entries from mock configs
TimothyWillard Aug 23, 2024
f0a2a0b
Merge unit-test-gempyor-parameters into GH-300/statistics-unit-tests-…
TimothyWillard Sep 9, 2024
4dd8683
Merge unit-test-gempyor-parameters into GH-300/statistics-unit-tests-…
TimothyWillard Sep 13, 2024
68b99d2
Merge main into GH-300/statistics-unit-tests-docs
TimothyWillard Oct 10, 2024
42d2f9d
Merge branch 'main' into GH-300/statistics-unit-tests-docs
TimothyWillard Oct 14, 2024
54de3bc
Merge main into GH-300/statistics-unit-tests-docs
TimothyWillard Oct 18, 2024
245 changes: 202 additions & 43 deletions flepimop/gempyor_pkg/src/gempyor/statistics.py
@@ -1,29 +1,69 @@
import xarray as xr
import pandas as pd
import numpy as np
"""
Abstractions for interacting with output statistic configurations.
This module provides the `Statistic` class, which represents an entry in the inference ->
statistics section.
"""

__all__ = ["Statistic"]


import confuse
import numpy as np
import scipy.stats
import xarray as xr


class Statistic:
"""
A statistic is a function that takes two time series and returns a scalar value.
It applies resample, scale, and regularization to the data before computing the statistic's log-loss.
Configuration:
- sim_var: the variable in the simulation data
- data_var: the variable in the ground truth data
- resample: resample the data before computing the statistic
- freq: the frequency to resample the data to
- aggregator: the aggregation function to use
- skipna: whether to skip NA values
- regularize: apply a regularization term to the data before computing the statistic
# SkipNA is False by default, which results in NA values broadcasting when resampling (e.g a NA withing a sum makes the whole sum a NA)
# if True, then NA are replaced with 0 (for sum), 1 for product, ...
# In doubt, plot stat.plot_transformed() to see the effect of the resampling
Encapsulates logic for representing/implementing output statistic configurations.
A statistic is a function that takes two time series and returns a scalar value. It
applies resample, scale, and regularization to the data before computing the
statistic's log-loss.
Attributes:
data_var: The variable in the ground truth data.
dist: The name of the distribution to use for calculating log-likelihood.
name: The human readable name for the statistic given during instantiation.
params: Distribution parameters used in the log-likelihood calculation and
dependent on `dist`.
regularizations: Regularization functions that are added to the log loss of this
statistic.
resample: If the data should be resampled before computing the statistic.
Defaults to `False`.
resample_aggregator_name: The name of the aggregation function to use. This
attribute is not set when a "resample" section is not defined in the
`statistic_config` arg.
resample_freq: The frequency to resample the data to if the `resample` attribute
is `True`. This attribute is not set when a "resample" section is not
defined in the `statistic_config` arg.
resample_skipna: If NAs should be skipped when aggregating. `False` by default.
This attribute is not set when a "resample" section is not defined in the
`statistic_config` arg.
scale: If the data should be rescaled before computing the statistic.
scale_func: The function to use when rescaling the data. Can be any function
exported by `numpy`. This attribute is not set when a "scale" value is not
defined in the `statistic_config` arg.
zero_to_one: Should non-zero values be coerced to 1 when calculating
log-likelihood.
"""

def __init__(self, name, statistic_config: confuse.ConfigView):
def __init__(self, name: str, statistic_config: confuse.ConfigView) -> None:
"""
Create a `Statistic` instance from a confuse config view.
Args:
name: A human readable name for the statistic, mostly used for error
messages.
statistic_config: A confuse configuration view object describing an output
statistic.
Raises:
ValueError: If an unsupported regularization name is provided via the
`statistic_config` arg. Currently only 'forecast' and 'allsubpop' are
supported.
"""
self.sim_var = statistic_config["sim_var"].as_str()
self.data_var = statistic_config["data_var"].as_str()
self.name = name
@@ -32,7 +72,7 @@ def __init__(self, name, statistic_config: confuse.ConfigView):
if statistic_config["regularize"].exists():
for reg_config in statistic_config["regularize"]: # Iterate over the list
reg_name = reg_config["name"].get()
reg_func = getattr(self, f"_{reg_name}_regularize")
reg_func = getattr(self, f"_{reg_name}_regularize", None)
if reg_func is None:
raise ValueError(f"Unsupported regularization: {reg_name}")
self.regularizations.append((reg_func, reg_config.get()))
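The `getattr` lookup above uses a `None` default so that an unknown regularization name becomes a clear `ValueError` rather than an `AttributeError`. A minimal standalone sketch of this name-to-method dispatch pattern (the class and method bodies here are illustrative stand-ins, not the real API):

```python
class StatSketch:
    """Illustrative stand-in for the regularization dispatch in `Statistic`."""

    def _forecast_regularize(self, **kwargs):
        return "forecast applied"

    def resolve(self, reg_name):
        # A missing method yields None instead of raising AttributeError,
        # so the bad config name can be reported explicitly.
        reg_func = getattr(self, f"_{reg_name}_regularize", None)
        if reg_func is None:
            raise ValueError(f"Unsupported regularization: {reg_name}")
        return reg_func


stat = StatSketch()
print(stat.resolve("forecast")())  # forecast applied
```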
@@ -47,12 +87,16 @@ def __init__(self, name, statistic_config: confuse.ConfigView):
self.resample_aggregator = ""
if resample_config["aggregator"].exists():
self.resample_aggregator_name = resample_config["aggregator"].get()
self.resample_skipna = False # TODO
if resample_config["aggregator"].exists() and resample_config["skipna"].exists():
self.resample_skipna = False
if (
resample_config["aggregator"].exists()
and resample_config["skipna"].exists()
):
self.resample_skipna = resample_config["skipna"].get()

self.scale = False
if statistic_config["scale"].exists():
self.scale = True
self.scale_func = getattr(np, statistic_config["scale"].get())

self.dist = statistic_config["likelihood"]["dist"].get()
@@ -62,58 +106,149 @@ def __init__(self, name, statistic_config: confuse.ConfigView):
self.params = {}

self.zero_to_one = False
# TODO: this should be set_zeros_to and only do it for the probabilily
if statistic_config["zero_to_one"].exists():
self.zero_to_one = statistic_config["zero_to_one"].get()

def _forecast_regularize(self, model_data, gt_data, **kwargs):
# scale the data so that the lastest X items are more important
def __str__(self) -> str:
return (
f"{self.name}: {self.dist} between {self.sim_var} "
f"(sim) and {self.data_var} (data)."
)

def __repr__(self) -> str:
return f"A Statistic(): {self.__str__()}"

def _forecast_regularize(
self,
model_data: xr.DataArray,
gt_data: xr.DataArray,
**kwargs: dict[str, int | float],
) -> float:
"""
Regularization function to add weight to more recent forecasts.
Args:
model_data: An xarray Dataset of the model data with date and subpop
dimensions.
gt_data: An xarray Dataset of the ground truth data with date and subpop
dimensions.
**kwargs: Optional keyword arguments that influence regularization.
Currently uses `last_n` for the number of observations to up weight and
`mult` for the coefficient of the regularization value.
Returns:
The log-likelihood of the `last_n` observation up weighted by a factor of
`mult`.
"""
# scale the data so that the latest X items are more important
last_n = kwargs.get("last_n", 4)
mult = kwargs.get("mult", 2)

last_n_llik = self.llik(model_data.isel(date=slice(-last_n, None)), gt_data.isel(date=slice(-last_n, None)))
last_n_llik = self.llik(
model_data.isel(date=slice(-last_n, None)),
gt_data.isel(date=slice(-last_n, None)),
)

return mult * last_n_llik.sum().sum().values
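The effect of `_forecast_regularize` on a per-date log-likelihood series can be sketched with plain NumPy (a simplified stand-in for the xarray-based method above, using the same `last_n`/`mult` defaults idea):

```python
import numpy as np


def forecast_regularize(llik_by_date, last_n=4, mult=2):
    # Re-count the `last_n` most recent log-likelihood values, scaled by
    # `mult`, so recent dates weigh more in the total loss.
    return mult * np.sum(llik_by_date[-last_n:])


llik = np.array([-3.0, -2.0, -1.5, -1.0, -0.5, -0.25])
print(forecast_regularize(llik, last_n=2, mult=2))  # -1.5
```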

def _allsubpop_regularize(self, model_data, gt_data, **kwargs):
"""add a regularization term that is the sum of all subpopulations"""
def _allsubpop_regularize(
self,
model_data: xr.DataArray,
gt_data: xr.DataArray,
**kwargs: dict[str, int | float],
) -> float:
"""
Regularization function to add the sum of all subpopulations.
Args:
model_data: An xarray Dataset of the model data with date and subpop
dimensions.
gt_data: An xarray Dataset of the ground truth data with date and subpop
dimensions.
**kwargs: Optional keyword arguments that influence regularization.
Currently uses `mult` for the coefficient of the regularization value.
Returns:
The sum of the subpopulations multiplied by `mult`.
"""
mult = kwargs.get("mult", 1)
llik_total = self.llik(model_data.sum("subpop"), gt_data.sum("subpop"))
return mult * llik_total.sum().sum().values

def __str__(self) -> str:
return f"{self.name}: {self.dist} between {self.sim_var} (sim) and {self.data_var} (data)."
def apply_resample(self, data: xr.DataArray) -> xr.DataArray:
"""
Resample a data set to the given frequency using the specified aggregation.
def __repr__(self) -> str:
return f"A Statistic(): {self.__str__()}"
Args:
data: An xarray dataset with "date" and "subpop" dimensions.
def apply_resample(self, data):
Returns:
A resampled dataset with similar dimensions to `data`.
"""
if self.resample:
aggregator_method = getattr(data.resample(date=self.resample_freq), self.resample_aggregator_name)
aggregator_method = getattr(
data.resample(date=self.resample_freq), self.resample_aggregator_name
)
return aggregator_method(skipna=self.resample_skipna)
else:
return data
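The aggregator is likewise looked up by name with `getattr`, here on a resample object. The same pattern can be shown with a pandas time series (pandas is used for illustration; the method itself operates on xarray objects, whose resample aggregators also accept `skipna`):

```python
import pandas as pd

series = pd.Series(
    [1.0, 2.0, 3.0, 4.0],
    index=pd.date_range("2024-01-01", periods=4, freq="D"),
)
# In the real class the aggregator name (e.g. "sum") comes from the config
aggregator_method = getattr(series.resample("2D"), "sum")
print(aggregator_method().tolist())  # [3.0, 7.0]
```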

def apply_scale(self, data):
def apply_scale(self, data: xr.DataArray) -> xr.DataArray:
"""
Scale a data set using the specified scaling function.
Args:
data: An xarray dataset with "date" and "subpop" dimensions.
Returns:
An xarray dataset of the same shape and dimensions as `data` with the
`scale_func` attribute applied.
"""
if self.scale:
return self.scale_func(data)
else:
return data

def apply_transforms(self, data):
def apply_transforms(self, data: xr.DataArray):
"""
Convenient wrapper for resampling and scaling a data set.
The resampling is applied *before* scaling which can affect the log-likelihood.
Args:
data: An xarray dataset with "date" and "subpop" dimensions.
Returns:
A scaled and resampled dataset with similar dimensions to `data`.
"""
data_scaled_resampled = self.apply_scale(self.apply_resample(data))
return data_scaled_resampled
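The docstring's note that resampling happens before scaling matters because most scale functions do not commute with aggregation. A small NumPy sketch, with pairwise sums standing in for date resampling:

```python
import numpy as np


def apply_transforms(values, scale_func=np.log10):
    # Resample first (pairwise sums stand in for date aggregation)...
    resampled = values.reshape(-1, 2).sum(axis=1)
    # ...then scale.
    return scale_func(resampled)


values = np.array([10.0, 90.0, 1.0, 9.0])
print(apply_transforms(values))  # [2. 1.]
# Scaling before resampling would give a different answer:
print(np.log10(values).reshape(-1, 2).sum(axis=1))
```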

def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray):
def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray) -> xr.DataArray:
"""
Compute the log-likelihood of observing the ground truth given model output.
Args:
model_data: An xarray Dataset of the model data with date and subpop
dimensions.
gt_data: An xarray Dataset of the ground truth data with date and subpop
dimensions.
Returns:
The log-likelihood of observing `gt_data` from the model `model_data` as an
xarray DataArray with a "subpop" dimension.
"""
dist_map = {
"pois": scipy.stats.poisson.logpmf,
"norm": lambda x, loc, scale: scipy.stats.norm.logpdf(
x, loc=loc, scale=self.params.get("scale", scale)
), # wrong:
"norm_cov": lambda x, loc, scale: scipy.stats.norm.logpdf(
x, loc=loc, scale=scale * loc.where(loc > 5, 5)
), # TODO: check, that it's really the loc
"nbinom": lambda x, n, p: scipy.stats.nbinom.logpmf(x, n=self.params.get("n"), p=model_data),
),
"nbinom": lambda x, n, p: scipy.stats.nbinom.logpmf(
x, n=self.params.get("n"), p=model_data
),
"rmse": lambda x, y: -np.log(np.nansum(np.sqrt((x - y) ** 2))),
"absolute_error": lambda x, y: -np.log(np.nansum(np.abs(x - y))),
}
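Each `dist_map` entry returns element-wise log-likelihoods. For example, the `"pois"` entry evaluates the Poisson log-PMF of the observed count with the model value as the rate, which can be checked against a hand computation:

```python
import numpy as np
import scipy.stats

gt = np.array([0.0, 2.0])     # observed counts
model = np.array([1.0, 2.0])  # model values used as Poisson rates
llik = scipy.stats.poisson.logpmf(gt, model)

# By hand: logpmf(0; 1) = -1, and logpmf(2; 2) = log(2) - 2
print(np.allclose(llik, [-1.0, np.log(2.0) - 2.0]))  # True
```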
@@ -133,20 +268,44 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray):

likelihood = xr.DataArray(likelihood, coords=gt_data.coords, dims=gt_data.dims)

# TODO: check the order of the arguments
return likelihood

def compute_logloss(self, model_data, gt_data):
def compute_logloss(
self, model_data: xr.Dataset, gt_data: xr.Dataset
) -> tuple[xr.DataArray, float]:
"""
Compute the logistic loss of observing the ground truth given model output.
Args:
model_data: An xarray Dataset of the model data with date and subpop
dimensions.
gt_data: An xarray Dataset of the ground truth data with date and subpop
dimensions.
Returns:
The logistic loss of observing `gt_data` from the model `model_data`
decomposed into the log-likelihood along the "subpop" dimension and
regularizations.
Raises:
ValueError: If `model_data` and `gt_data` do not have the same shape.
"""
model_data = self.apply_transforms(model_data[self.sim_var])
gt_data = self.apply_transforms(gt_data[self.data_var])

if not model_data.shape == gt_data.shape:
raise ValueError(
f"{self.name} Statistic error: data and groundtruth do not have the same shape: model_data.shape={model_data.shape} != gt_data.shape={gt_data.shape}"
(
f"{self.name} Statistic error: data and groundtruth do not have "
f"the same shape: model_data.shape={model_data.shape} != "
f"gt_data.shape={gt_data.shape}"
)
)

regularization = 0
regularization = 0.0
for reg_func, reg_config in self.regularizations:
regularization += reg_func(model_data=model_data, gt_data=gt_data, **reg_config) # Pass config parameters
regularization += reg_func(
model_data=model_data, gt_data=gt_data, **reg_config
) # Pass config parameters

return self.llik(model_data, gt_data).sum("date"), regularization
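The returned pair separates the per-subpopulation log-likelihood (summed over dates) from the scalar regularization term. The decomposition can be sketched in NumPy, with axis 0 standing in for the "date" dimension:

```python
import numpy as np


def compute_logloss(llik, regularizations):
    """Sketch of the return structure: `llik` has shape (n_dates, n_subpops)."""
    llik_by_subpop = llik.sum(axis=0)       # collapse the date axis
    regularization = sum(regularizations)   # scalar penalty terms
    return llik_by_subpop, regularization


llik = np.array([[-1.0, -2.0], [-0.5, -0.5]])
by_subpop, reg = compute_logloss(llik, [0.25, 0.75])
print(by_subpop.tolist(), reg)  # [-1.5, -2.5] 1.0
```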