Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a statistical model fitted to the original dataset to synthesize data #179

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ pyrealm.egg-info

# Data
pyrealm_build_data/inputs_data_24.25.nc
pyrealm_build_data/eda.py

# Profiling
prof/
4,096 changes: 1,985 additions & 2,111 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ flake8 = "^4.0.1"
flake8-docstrings = "^1.6.0"
mypy = "^0.991"
isort = "^5.12.0"
pandas = ">1.3.0"
pandas = "^2.2.0"
matplotlib = "^3.5.2"
ipython = "^8.9.0"

Expand Down
Binary file added pyrealm_build_data/data_model.nc
Binary file not shown.
81 changes: 81 additions & 0 deletions pyrealm_build_data/synth_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""This script uses a parametrized model to compress the input dataset.

It fits a time series model to the input data and stores the model parameters.
The dataset can then be reconstructed from the model parameters using the `reconstruct`
function, provided with a custom time index.
"""
tztsai marked this conversation as resolved.
Show resolved Hide resolved
from typing import Tuple

import numpy as np
import pandas as pd
import xarray as xr

# Plausible (lower, upper) clipping bounds applied to each reconstructed
# variable so synthesised values stay within a sensible range for that field.
VAR_BOUNDS = {
    "temp": (-25, 80),
    "patm": (3e4, 11e4),
    "vpd": (0, 1e4),
    "co2": (0, 1e3),
    "fapar": (0, 1),
    "ppfd": (0, 1e4),
}
tztsai marked this conversation as resolved.
Show resolved Hide resolved


def make_time_features(t: np.ndarray) -> pd.DataFrame:
    """Build a regression design matrix of time features for times ``t``.

    The columns are a constant, a linear trend measured in (Julian) years
    since 2000-01-01, and sine/cosine pairs for a fixed set of frequencies
    expressed in cycles per year.
    """
    index = pd.to_datetime(t).rename("time")

    feats = pd.DataFrame(index=index)
    feats["const"] = 1.0
    # Years elapsed since the epoch 2000-01-01 (365.25-day years).
    feats["linear"] = (index - pd.Timestamp("2000-01-01")) / pd.Timedelta("365.25d")

    frequencies = (730.5, 365.25, 12, 6, 4, 3, 2, 1, 1 / 2, 1 / 3, 1 / 4, 1 / 6)
    for freq in frequencies:
        phase = 2 * np.pi * freq * feats["linear"]
        feats[f"freq_{freq:.2f}_sin"] = np.sin(phase)
        feats[f"freq_{freq:.2f}_cos"] = np.cos(phase)

    return feats


def fit_ts_model(df: pd.DataFrame, fs: pd.DataFrame) -> Tuple[pd.DataFrame, float]:
    """Least-squares fit of the feature matrix ``fs`` to the observations ``df``.

    Returns the fitted coefficients — one row per data column of ``df``, one
    column per feature of ``fs`` — together with the mean squared
    reconstruction error normalised by the overall variance of the data.
    """
    # Drop all-NaN columns, then impute remaining gaps with column means.
    cleaned = df.dropna(axis=1, how="all").fillna(df.mean())
    targets = cleaned.values  # (times, locs)
    design = fs.values  # (times, feats)
    # Ordinary least squares via the Moore-Penrose pseudoinverse.
    coeffs = np.linalg.pinv(design) @ targets  # (feats, locs)
    residual = design @ coeffs - targets
    loss = np.mean(residual**2) / np.var(targets)
    params = pd.DataFrame(coeffs.T, index=cleaned.columns, columns=fs.columns)
    return params, loss


def reconstruct(ds: xr.Dataset, dt: np.ndarray | pd.DatetimeIndex) -> xr.Dataset:
    """Rebuild the synthetic dataset from fitted model parameters.

    Evaluates each variable's fitted coefficients against the time features
    for ``dt`` and clips the result to that variable's plausible bounds.
    """
    features = make_time_features(dt).to_xarray().to_dataarray()
    rebuilt = {
        name: (coeffs @ features).clip(*VAR_BOUNDS[name])
        for name, coeffs in ds.items()
    }
    return xr.Dataset(rebuilt)


if __name__ == "__main__":
    # Fit the time-series model to the full input dataset and store the
    # fitted coefficients as a compact netCDF "model" file.
    ds = xr.open_dataset("pyrealm_build_data/inputs_data_24.25.nc")

    # Keep only locations where no variable is NaN for the whole time series.
    mask = ~ds.isnull().all("time").to_dataarray().any("variable")
    ds = ds.where(mask, drop=True)

    # Variables fitted with a reduced feature set: patm is modelled as
    # constant in time, co2 as a constant plus a linear trend.
    special_time_features = dict(
        patm=["const"],
        co2=["const", "linear"],
    )

    features = make_time_features(ds.time)
    model = xr.Dataset()

    for k in ds.data_vars:
        print("Fitting", k)
        da = ds[k].isel(time=slice(None, None, 4))  # downsample along time
        df = da.to_series().unstack("time").T  # (datetimes, locations)
        fs = features.loc[df.index]  # (datetimes, features)
        # Restrict to the variable's special feature set, if it has one.
        fs = fs[special_time_features.get(k, fs.columns)]
        ps, r = fit_ts_model(df, fs)  # (locations, features)
        print("Loss:", r)
        # Zero-fill coefficients for features this variable was not fitted
        # with, so every variable shares the same feature dimension.
        ps[features.keys().difference(ps.columns)] = 0.0
        model[k] = ps.to_xarray().to_dataarray()

    model.to_netcdf("pyrealm_build_data/data_model.nc")
tztsai marked this conversation as resolved.
Show resolved Hide resolved
46 changes: 46 additions & 0 deletions tests/regression/data/test_synth_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Test the quality of the synthetic data generated from the model parameters."""

import numpy as np
import pytest
import xarray as xr

try:
    DATASET = xr.open_dataset("pyrealm_build_data/inputs_data_24.25.nc")
    VARS = DATASET.data_vars
except (FileNotFoundError, ValueError):
    # A missing file raises FileNotFoundError; an un-smudged Git LFS pointer
    # stub raises ValueError. Either way the real dataset is unavailable,
    # so skip the whole module rather than erroring at collection time.
    pytest.skip("Original LFS dataset not checked out.", allow_module_level=True)


def r2_score(y_true: "xr.DataArray", y_pred: "xr.DataArray") -> float:
    """Compute the coefficient of determination (R²) of ``y_pred`` vs ``y_true``.

    Works on any array-like supporting elementwise arithmetic with ``sum``
    and ``mean`` methods (xarray DataArray, numpy ndarray, ...).
    """
    sse = ((y_true - y_pred) ** 2).sum()
    sst = ((y_true - y_true.mean()) ** 2).sum()
    # Cast to a plain float so the annotated return type actually holds: the
    # original returned a 0-d DataArray/ndarray, not a Python float.
    return float(1 - sse / sst)


@pytest.fixture
def syndata(modelpath="pyrealm_build_data/data_model.nc"):
    """Synthetic dataset reconstructed from the stored model parameters."""
    from pyrealm_build_data.synth_data import reconstruct

    params = xr.open_dataset(modelpath)
    times = xr.date_range("2012-01-01", "2018-01-01", freq="12h")
    return reconstruct(params, times)


@pytest.fixture
def dataset(syndata):
    """Slice of the original dataset aligned with the synthetic time axis."""
    return DATASET.sel(time=syndata.time)


@pytest.mark.parametrize("var", VARS)
def test_synth_data_quality(dataset, syndata, var):
    """Check the reconstruction of ``var`` matches the original (R² > 0.85).

    Scores a random subsample of times and latitudes to keep the
    comparison cheap.
    """
    # Seed the sampler: the original used the global unseeded np.random,
    # which could make a borderline R² fail intermittently.
    rng = np.random.default_rng(0)
    times = syndata.time[rng.choice(syndata.time.size, 1000, replace=False)]
    lats = syndata.lat[rng.choice(syndata.lat.size, 100, replace=False)]
    truth = dataset[var].sel(lat=lats, time=times)
    pred = syndata[var].sel(lat=lats, time=times)
    score = r2_score(truth, pred)
    print(f"R2 score for {var} is {score:.2f}")
    assert score > 0.85
Loading