Merge branch 'new_inference' into time2date
jcblemai authored Apr 16, 2024
2 parents 72c1f4a + 5888d57 commit f60e7e9
Showing 2 changed files with 169 additions and 0 deletions.
56 changes: 56 additions & 0 deletions flepimop/gempyor_pkg/src/gempyor/logloss.py
@@ -0,0 +1,56 @@
import xarray as xr
import pandas as pd
import numpy as np
import confuse
import scipy.stats
from . import statistics


## https://docs.xarray.dev/en/stable/user-guide/indexing.html#assigning-values-with-indexing
# TODO: add an automatic test showing that the log-loss is largest when gt == modeldata

class LogLoss:
    def __init__(self, inference_config: confuse.ConfigView, data_dir: str = "."):
        self.gt = pd.read_csv(f"{data_dir}/{inference_config['gt_data_path'].get()}")
        self.gt["date"] = pd.to_datetime(self.gt["date"])
        self.gt = self.gt.set_index("date")
self.statistics = {}
for key, value in inference_config["statistics"].items():
self.statistics[key] = statistics.Statistic(key, value)

def plot_gt(self, ax, subpop, statistic, **kwargs):
self.gt[self.gt["subpop"] == subpop].plot(y=statistic, ax=ax, **kwargs)


def compute_logloss(self, model_df, modinf):
"""
Compute logloss for all statistics
model_df: DataFrame indexed by date
modinf: model information
TODO: support kwargs for emcee, and this looks very slow
"""
        logloss = xr.DataArray(
            0.0,  # float so the per-statistic log-likelihoods can be accumulated below
            dims=["statistic", "subpop"],
            coords={
                "statistic": list(self.statistics.keys()),
                "subpop": modinf.subpop_struct.subpop_names,
            },
        )

for subpop in modinf.subpop_struct.subpop_names:
# essential to sort by index (date here)
gt_s = self.gt[self.gt["subpop"] == subpop].sort_index()
model_df_s = model_df[model_df["subpop"] == subpop].sort_index()

# Places where data and model overlap
first_date = max(gt_s.index.min(), model_df_s.index.min())
last_date = min(gt_s.index.max(), model_df_s.index.max())

gt_s = gt_s.loc[first_date:last_date].drop(["subpop"], axis=1)
model_df_s = model_df_s.drop(["subpop"], axis=1).loc[first_date:last_date]

# TODO: add whole US!! option

            for key, stat in self.statistics.items():
                # accumulate along the "statistic" dimension declared above, using the per-subpop slices
                logloss.loc[dict(statistic=key, subpop=subpop)] += stat.compute_logloss(model_df_s, gt_s)

return logloss


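The TODO in logloss.py asks for an automatic check that the log-loss is largest when the model output equals the ground truth. Below is a minimal sketch of such a check; the config keys mirror what LogLoss.__init__ reads, while gt_data.csv, the subpop name, and the ModInfStub stand-in for the model-information object are illustrative, not part of this commit.

import confuse
import pandas as pd

from gempyor.logloss import LogLoss

# Write a tiny ground-truth file so the sketch is self-contained (file name is illustrative).
pd.DataFrame(
    {
        "date": pd.date_range("2024-01-01", periods=5),
        "subpop": "small_province",
        "incidH": [3, 5, 8, 6, 4],
    }
).to_csv("gt_data.csv", index=False)

# Hypothetical inference config; the keys mirror what LogLoss.__init__ reads.
cfg = confuse.Configuration("flepimop_example", read=False)
cfg.set(
    {
        "gt_data_path": "gt_data.csv",
        "statistics": {
            "incidH": {"sim_var": "incidH", "data_var": "incidH", "likelihood": {"dist": "pois"}},
        },
    }
)


class ModInfStub:
    # minimal stand-in for the model-information object used by compute_logloss
    class subpop_struct:
        subpop_names = ["small_province"]


ll = LogLoss(cfg, data_dir=".")

perfect = ll.gt.copy()    # "model output" identical to the ground truth
perturbed = ll.gt.copy()
perturbed["incidH"] *= 2  # a deliberately worse fit

assert float(ll.compute_logloss(perfect, ModInfStub()).sum()) >= float(
    ll.compute_logloss(perturbed, ModInfStub()).sum()
)

For a Poisson likelihood the probability of an observation is maximized when the model rate equals that observation, so the perfect run should never score below the perturbed one.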
113 changes: 113 additions & 0 deletions flepimop/gempyor_pkg/src/gempyor/statistics.py
@@ -0,0 +1,113 @@
import xarray as xr
import pandas as pd
import numpy as np
import confuse
import scipy.stats

class Statistic:
"""
A statistic is a function that takes two time series and returns a scalar value.
It applies resample, scale, and regularization to the data before computing the statistic's log-loss.
Configuration:
- sim_var: the variable in the simulation data
- data_var: the variable in the ground truth data
- resample: resample the data before computing the statistic
- freq: the frequency to resample the data to
- aggregator: the aggregation function to use
- skipna: whether to skip NA values
- regularize: apply a regularization term to the data before computing the statistic
"""
def __init__(self, name, statistic_config: confuse.ConfigView):
self.sim_var = statistic_config["sim_var"].as_str()
self.data_var = statistic_config["data_var"].as_str()
self.name = name

self.regularizations = [] # A list to hold regularization functions and configs
if statistic_config["regularize"].exists():
for reg_config in statistic_config["regularize"]: # Iterate over the list
reg_name = reg_config["name"].get()
                reg_func = getattr(self, f"_{reg_name}_regularize", None)  # default to None so unknown names are caught below
                if reg_func is None:
                    raise ValueError(f"Unsupported regularization: {reg_name}")
self.regularizations.append((reg_func, reg_config))

self.resample = False
if statistic_config["resample"].exists():
self.resample = True
resample_config = statistic_config["resample"]
self.resample_freq = ""
if resample_config["freq"].exists():
self.resample_freq = resample_config["freq"].get()
self.resample_aggregator = ""
if resample_config["aggregator"].exists():
self.resample_aggregator = getattr(pd.Series, resample_config["aggregator"].get())
self.resample_skipna = False # TODO
if resample_config["aggregator"].exists() and resample_config["skipna"].exists():
self.resample_skipna = resample_config["skipna"].get()

        self.scale = False
        if statistic_config["scale"].exists():
            self.scale = True  # flag checked in apply_scale(); without it scale_func would never be applied
            self.scale_func = getattr(np, statistic_config["scale"].get())

        self.dist = statistic_config["likelihood"]["dist"].get()
        # Extra distribution parameters used by compute_logloss (e.g. "scale" for norm, "n" for nbinom).
        # Assumed to live under likelihood.params in the config; empty if not provided.
        self.params = {}
        if statistic_config["likelihood"]["params"].exists():
            self.params = dict(statistic_config["likelihood"]["params"].get())

    def _forecast_regularize(self, data, reg_config):
        # scale the data so that the latest `last_n` points carry more weight
        last_n = reg_config["last_n"].get()
        mult = reg_config["mult"].get()
        # multiply the last n items by mult, e.g. last_n=3, mult=2 on 5 points gives weights [1, 1, 1, 2, 2]
        reg_data = data * np.concatenate([np.ones(data.shape[0] - last_n), np.ones(last_n) * mult])
        return reg_data

    def _allsubpop_regularize(self, data, reg_config):
        """add a regularization term that is the sum over all subpopulations"""
        pass  # TODO

def __str__(self) -> str:
return f"{self.name}: {self.dist} between {self.sim_var} (sim) and {self.data_var} (data)."

def __repr__(self) -> str:
return f"A Statistic(): {self.__str__()}"

def apply_resample(self, data):
if self.resample:
return data.resample(self.resample_freq).agg(self.resample_aggregator, skipna=self.resample_skipna)
else:
return data

def apply_scale(self, data):
if self.scale:
return self.scale_func(data)
else:
return data

def apply_transforms(self, data):
data_scaled_resampled = self.apply_scale(self.apply_resample(data))
# Apply regularizations sequentially
        for reg_func, reg_config in self.regularizations:
            data_scaled_resampled = reg_func(data_scaled_resampled, reg_config)  # each regularization reads its own config view
return data_scaled_resampled


    def compute_logloss(self, model_data, gt_data):
        model_data = self.apply_transforms(model_data[self.sim_var])
        gt_data = self.apply_transforms(gt_data[self.data_var])

        # check shapes before evaluating the likelihood
        if not model_data.shape == gt_data.shape:
            raise ValueError(f"{self.name} Statistic error: data and groundtruth do not have the same shape")

        # TODO: check the order of the arguments
        # Each entry takes (ground-truth, model) series; extra parameters (e.g. "scale" for norm,
        # "n" for nbinom) are read from the likelihood configuration via self.params.
        dist_map = {
            "pois": lambda x, mu: scipy.stats.poisson.pmf(x, mu),
            "norm": lambda x, loc: scipy.stats.norm.pdf(x, loc=loc, scale=self.params.get("scale", 1.0)),
            # provisional parameterization: dispersion `n` from the config, mean given by the model data
            "nbinom": lambda x, mu: scipy.stats.nbinom.pmf(
                x, n=self.params.get("n"), p=self.params.get("n") / (self.params.get("n") + mu)
            ),
            "rmse": lambda x, y: np.sqrt(np.mean((x - y) ** 2)),
            "absolute_error": lambda x, y: np.mean(np.abs(x - y)),
        }
        if self.dist not in dist_map:
            raise ValueError(f"Invalid distribution specified: {self.dist}")

        likelihood = dist_map[self.dist](gt_data, model_data)

        # sum over time so each (statistic, subpop) pair contributes a single scalar log-likelihood
        return np.sum(np.log(likelihood))

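For reference, here is a sketch of how a single Statistic might be configured and evaluated on its own, assuming the likelihood parameters sit under likelihood.params as read by the constructor above; the statistic name, resampling frequency, and data values are illustrative.

import confuse
import numpy as np
import pandas as pd

from gempyor.statistics import Statistic

# Hypothetical configuration for one statistic; keys follow the class docstring.
cfg = confuse.Configuration("flepimop_example", read=False)
cfg.set(
    {
        "incidH": {
            "sim_var": "incidH",
            "data_var": "incidH",
            "resample": {"freq": "W", "aggregator": "sum"},
            "likelihood": {"dist": "norm", "params": {"scale": 2.0}},
        }
    }
)
stat = Statistic("incidH", cfg["incidH"])

# Two small date-indexed frames standing in for ground truth and model output.
dates = pd.date_range("2024-01-01", periods=14, freq="D")
gt = pd.DataFrame({"incidH": np.linspace(10.0, 23.0, 14)}, index=dates)
model = pd.DataFrame({"incidH": np.linspace(11.0, 24.0, 14)}, index=dates)

print(stat)                              # incidH: norm between incidH (sim) and incidH (data).
print(stat.compute_logloss(model, gt))   # a single log-likelihood after weekly resampling

Weekly resampling is applied to both series before the normal density is evaluated, so the two inputs keep the same shape and the result is one scalar per statistic and subpopulation.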