Commit: Merge branch 'new_inference' into time2date
Showing 2 changed files with 169 additions and 0 deletions.
@@ -0,0 +1,56 @@
import xarray as xr
import pandas as pd
import numpy as np
import confuse
import scipy.stats
from . import statistics


## https://docs.xarray.dev/en/stable/user-guide/indexing.html#assigning-values-with-indexing
# TODO: add an automatic test that shows that the loss is biggest when gt == modeldata

class LogLoss:
    def __init__(self, inference_config: confuse.ConfigView, data_dir: str = "."):
        self.gt = pd.read_csv(f"{data_dir}/{inference_config['gt_data_path'].get()}")
        self.gt["date"] = pd.to_datetime(self.gt["date"])
        self.gt = self.gt.set_index("date")
        self.statistics = {}
        for key, value in inference_config["statistics"].items():
            self.statistics[key] = statistics.Statistic(key, value)

    def plot_gt(self, ax, subpop, statistic, **kwargs):
        self.gt[self.gt["subpop"] == subpop].plot(y=statistic, ax=ax, **kwargs)

    def compute_logloss(self, model_df, modinf):
        """
        Compute the log-loss for all statistics.
        model_df: DataFrame indexed by date
        modinf: model information
        TODO: support kwargs for emcee, and this looks very slow
        """
        logloss = xr.DataArray(
            0.0,  # float, so that accumulating log-likelihoods does not truncate to int
            dims=["statistic", "subpop"],
            coords={
                "statistic": list(self.statistics.keys()),  # was self.statistics.key(), which does not exist
                "subpop": modinf.subpop_struct.subpop_names,
            },
        )

        for subpop in modinf.subpop_struct.subpop_names:
            # essential to sort by index (date here)
            gt_s = self.gt[self.gt["subpop"] == subpop].sort_index()
            model_df_s = model_df[model_df["subpop"] == subpop].sort_index()

            # restrict to the dates where data and model overlap
            first_date = max(gt_s.index.min(), model_df_s.index.min())
            last_date = min(gt_s.index.max(), model_df_s.index.max())

            gt_s = gt_s.loc[first_date:last_date].drop(["subpop"], axis=1)
            model_df_s = model_df_s.drop(["subpop"], axis=1).loc[first_date:last_date]

            # TODO: add whole US!! option

            for key, stat in self.statistics.items():
                # sum the per-date log-likelihoods into the (statistic, subpop) cell;
                # use the subpop-filtered model_df_s, not the full model_df, and note
                # the coordinate is named "statistic", not "statistics"
                logloss.loc[dict(statistic=key, subpop=subpop)] += stat.compute_logloss(model_df_s, gt_s).sum()

        return logloss

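A minimal usage sketch for the class above. Everything here is hypothetical and constructed only to satisfy what LogLoss reads: a ground-truth CSV with date/subpop columns, a config carrying the gt_data_path and statistics keys (assuming confuse's Configuration/set overlay API), and a SimpleNamespace stand-in for the model-information object, which is assumed only to expose subpop_struct.subpop_names.

    import confuse
    import pandas as pd
    from types import SimpleNamespace

    # Hypothetical ground-truth file with the columns LogLoss expects.
    dates = pd.date_range("2024-01-01", periods=10, freq="D")
    pd.DataFrame({"date": dates, "subpop": "NY", "incidH": range(10)}).to_csv("gt.csv", index=False)

    # Hypothetical config; "gt_data_path" and "statistics" are the keys __init__ reads.
    config = confuse.Configuration("example", read=False)
    config.set({
        "gt_data_path": "gt.csv",
        "statistics": {
            "incidH": {"sim_var": "incidH", "data_var": "incidH", "likelihood": {"dist": "pois"}},
        },
    })

    # Stand-in for modinf; only subpop_struct.subpop_names is used by compute_logloss.
    modinf = SimpleNamespace(subpop_struct=SimpleNamespace(subpop_names=["NY"]))

    model_df = pd.DataFrame({"subpop": "NY", "incidH": range(10)}, index=dates)
    logloss = LogLoss(config, data_dir=".")
    print(logloss.compute_logloss(model_df, modinf))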
@@ -0,0 +1,113 @@
import xarray as xr
import pandas as pd
import numpy as np
import confuse
import scipy.stats

class Statistic:
    """
    A statistic is a function that takes two time series and returns a scalar value.
    It applies resample, scale, and regularization to the data before computing the statistic's log-loss.
    Configuration:
    - sim_var: the variable in the simulation data
    - data_var: the variable in the ground truth data
    - resample: resample the data before computing the statistic
        - freq: the frequency to resample the data to
        - aggregator: the aggregation function to use
        - skipna: whether to skip NA values
    - regularize: apply regularization terms to the data before computing the statistic
    - scale: the name of a numpy function to apply to the data (e.g. log)
    - likelihood: the distribution used to compute the log-loss
        - dist: the distribution name
    """

    def __init__(self, name, statistic_config: confuse.ConfigView):
        self.sim_var = statistic_config["sim_var"].as_str()
        self.data_var = statistic_config["data_var"].as_str()
        self.name = name

        self.regularizations = []  # a list of (function, config) pairs
        if statistic_config["regularize"].exists():
            for reg_config in statistic_config["regularize"]:  # iterate over the list
                reg_name = reg_config["name"].get()
                # default to None so the check below actually fires for unknown names
                # (a bare getattr would raise AttributeError instead of returning None)
                reg_func = getattr(self, f"_{reg_name}_regularize", None)
                if reg_func is None:
                    raise ValueError(f"Unsupported regularization: {reg_name}")
                self.regularizations.append((reg_func, reg_config))

        self.resample = False
        if statistic_config["resample"].exists():
            self.resample = True
            resample_config = statistic_config["resample"]
            self.resample_freq = ""
            if resample_config["freq"].exists():
                self.resample_freq = resample_config["freq"].get()
            self.resample_aggregator = ""
            if resample_config["aggregator"].exists():
                self.resample_aggregator = getattr(pd.Series, resample_config["aggregator"].get())
            self.resample_skipna = False  # TODO
            if resample_config["aggregator"].exists() and resample_config["skipna"].exists():
                self.resample_skipna = resample_config["skipna"].get()

        self.scale = False
        if statistic_config["scale"].exists():
            self.scale = True  # was never set to True, so apply_scale was a no-op
            self.scale_func = getattr(np, statistic_config["scale"].get())

        self.dist = statistic_config["likelihood"]["dist"].get()
        # optional likelihood parameters; compute_logloss reads self.params, which was
        # previously never initialized (assumed here to live under likelihood.params)
        self.params = {}
        if statistic_config["likelihood"]["params"].exists():
            self.params = statistic_config["likelihood"]["params"].get()

    def _forecast_regularize(self, data, last_n, mult):
        # scale the data so that the latest `last_n` items are more important:
        # multiply the last n items by mult (last_n and mult come from the
        # regularization config entry; self.regularization_config never existed)
        reg_data = data * np.concatenate([np.ones(data.shape[0] - last_n), np.ones(last_n) * mult])
        return reg_data

    def _allsubpop_regularize(self, data):
        """add a regularization term that is the sum of all subpopulations"""
        pass  # TODO

    def __str__(self) -> str:
        return f"{self.name}: {self.dist} between {self.sim_var} (sim) and {self.data_var} (data)."

    def __repr__(self) -> str:
        return f"A Statistic(): {self.__str__()}"

    def apply_resample(self, data):
        if self.resample:
            return data.resample(self.resample_freq).agg(self.resample_aggregator, skipna=self.resample_skipna)
        else:
            return data

    def apply_scale(self, data):
        if self.scale:
            return self.scale_func(data)
        else:
            return data

    def apply_transforms(self, data):
        data_scaled_resampled = self.apply_scale(self.apply_resample(data))
        # apply regularizations sequentially, passing each entry's config parameters
        # (everything except the "name" key, which selects the function itself)
        for reg_func, reg_config in self.regularizations:
            reg_params = {k: v for k, v in reg_config.get().items() if k != "name"}
            data_scaled_resampled = reg_func(data_scaled_resampled, **reg_params)
        return data_scaled_resampled

    def compute_logloss(self, model_data, gt_data):
        model_data = self.apply_transforms(model_data[self.sim_var])
        gt_data = self.apply_transforms(gt_data[self.data_var])
        # TODO: check the order of the arguments
        # extra likelihood parameters are read from self.params inside the lambdas
        dist_map = {
            "pois": scipy.stats.poisson.pmf,
            # scale defaults to 1 when not provided in the likelihood params
            "norm": lambda x, loc: scipy.stats.norm.pdf(x, loc=loc, scale=self.params.get("scale", 1)),
            # assumes the mean/dispersion parameterization (mean = mu, p = n / (n + mu)),
            # with the dispersion n required in the likelihood params; the original
            # p=model_data was marked wrong
            "nbinom": lambda x, mu: scipy.stats.nbinom.pmf(x, n=self.params.get("n"), p=self.params.get("n") / (self.params.get("n") + mu)),
            "rmse": lambda x, y: np.sqrt(np.mean((x - y) ** 2)),
            "absolute_error": lambda x, y: np.mean(np.abs(x - y)),
        }
        if self.dist not in dist_map:
            raise ValueError(f"Invalid distribution specified: {self.dist}")

        # check shapes before evaluating the likelihood, not after
        if not model_data.shape == gt_data.shape:
            raise ValueError(f"{self.name} Statistic error: data and groundtruth do not have the same shape")

        likelihood = dist_map[self.dist](gt_data, model_data)

        return np.log(likelihood)
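A hedged configuration sketch for Statistic itself, exercising the full option set the constructor parses. All names and values are hypothetical, chosen only to match the keys read above (sim_var, data_var, resample, regularize, scale, likelihood), and it assumes confuse's Configuration/set overlay API:

    import confuse
    import numpy as np
    import pandas as pd

    config = confuse.Configuration("example", read=False)
    config.set({
        "incidH": {
            "sim_var": "incidH",
            "data_var": "incidH",
            "resample": {"freq": "W", "aggregator": "sum", "skipna": True},
            "regularize": [{"name": "forecast", "last_n": 2, "mult": 2.0}],
            "scale": "log1p",  # looked up on numpy, so any numpy function name works
            "likelihood": {"dist": "norm", "params": {"scale": 1.0}},
        },
    })

    stat = Statistic("incidH", config["incidH"])
    print(stat)  # incidH: norm between incidH (sim) and incidH (data).

    # Four full weeks of identical model and ground-truth series: each weekly
    # log-density should equal log(norm.pdf(0)) = -0.5 * log(2 * pi).
    dates = pd.date_range("2024-01-01", periods=28, freq="D")
    model = pd.DataFrame({"incidH": np.arange(28.0)}, index=dates)
    gt = pd.DataFrame({"incidH": np.arange(28.0)}, index=dates)
    print(stat.compute_logloss(model, gt))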