Add data tree and dataset formats to linear regression #566

Open · wants to merge 7 commits into base: main
2 changes: 1 addition & 1 deletion mesmer/create_emulations/create_emus_lv.py
@@ -271,6 +271,6 @@ def create_emus_lv_OLS(params_lv, preds_lv):
         lr.params = params
         prediction = lr.predict(predictors=preds)

-        emus_lv[scen][targ] = prediction.values
+        emus_lv[scen][targ] = prediction.values.transpose()

     return emus_lv
2 changes: 1 addition & 1 deletion mesmer/create_emulations/utils.py
@@ -114,7 +114,7 @@ def _gather_lr_params(params_dict, targ, dims):
         intercept = xr.DataArray(params_dict["intercept"][targ], dims=dims)
         fit_intercept = True
     else:
-        intercept = 0
+        intercept = xr.zeros_like(params[pred])
         fit_intercept = False

     params["intercept"] = intercept
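The scalar 0 is swapped for an array of zeros so the intercept carries the same dims and coords as the other parameters. A minimal sketch of the difference, using a hypothetical coefficient array in place of params[pred]:

    import numpy as np
    import xarray as xr

    # hypothetical per-cell coefficient, standing in for params[pred]
    coef = xr.DataArray(np.ones(4), dims="cells")

    intercept = xr.zeros_like(coef)  # DataArray of zeros with dims ("cells",),
                                     # unlike the bare scalar 0 it replaces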
115 changes: 82 additions & 33 deletions mesmer/stats/_linear_regression.py
@@ -1,9 +1,16 @@
-from collections.abc import Mapping
-
 import numpy as np
 import xarray as xr
+from datatree import DataTree, map_over_subtree

-from mesmer.core.utils import _check_dataarray_form, _check_dataset_form, _to_set
+from mesmer.core.datatree import (
+    _extract_single_dataarray_from_dt,
+    collapse_datatree_into_dataset,
+)
+from mesmer.core.utils import (
+    _check_dataarray_form,
+    _check_dataset_form,
+    _to_set,
+)


 class LinearRegression:
@@ -14,7 +21,7 @@ def __init__(self):

     def fit(
         self,
-        predictors: Mapping[str, xr.DataArray],
+        predictors: dict[str, xr.DataArray] | DataTree | xr.Dataset,
         target: xr.DataArray,
         dim: str,
         weights: xr.DataArray | None = None,
@@ -25,9 +32,10 @@

         Parameters
         ----------
-        predictors : dict of xr.DataArray
-            A dict of DataArray objects used as predictors. Must be 1D and contain
-            `dim`.
+        predictors : dict of xr.DataArray | DataTree | xr.Dataset
+            A dict of DataArray objects used as predictors or a DataTree holding
+            each predictor in a leaf. Each predictor must be 1D and contain `dim`.
+            If predictors is an xr.Dataset, each predictor must be a DataArray.
         target : xr.DataArray
             Target DataArray. Must be 2D and contain `dim`.
         dim : str
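For illustration, a minimal sketch of the three accepted predictor formats (hypothetical data; the LinearRegression import path is an assumption):

    import numpy as np
    import xarray as xr
    from datatree import DataTree

    from mesmer.stats import LinearRegression  # assumed import path

    time = np.arange(30)
    tas = xr.DataArray(np.linspace(0.0, 1.0, 30), dims="time", coords={"time": time})
    target = xr.DataArray(
        np.random.randn(5, 30), dims=("cells", "time"), coords={"time": time}
    )

    lr = LinearRegression()

    # a dict of 1D DataArrays, each containing `dim`
    lr.fit(predictors={"tas": tas}, target=target, dim="time")

    # the same fit from a Dataset with one DataArray per predictor ...
    lr.fit(predictors=xr.Dataset({"tas": tas}), target=target, dim="time")

    # ... or from a DataTree with one predictor per leaf
    dt = DataTree.from_dict({"tas": tas.to_dataset(name="tas")})
    lr.fit(predictors=dt, target=target, dim="time")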
@@ -52,16 +60,18 @@

     def predict(
         self,
-        predictors: Mapping[str, xr.DataArray],
+        predictors: dict[str, xr.DataArray] | DataTree | xr.Dataset,
         exclude=None,
-    ):
+    ) -> xr.DataArray:
         """
         Predict using the linear model.

         Parameters
         ----------
-        predictors : dict of xr.DataArray
-            A dict of DataArray objects used as predictors. Must be 1D and contain `dim`.
+        predictors : dict of xr.DataArray | DataTree | xr.Dataset
+            A dict of ``DataArray`` objects used as predictors or a ``DataTree``
+            holding each predictor in a leaf. Each predictor must be 1D and contain
+            ``dim``. If predictors is an ``xr.Dataset``, each predictor must be a
+            single ``DataArray``.
         exclude : str or set of str, default: None
             Set of variables to exclude in the prediction. May include ``"intercept"``
             to initialize the prediction with 0.
@@ -78,7 +88,7 @@

         non_predictor_vars = {"intercept", "weights", "fit_intercept"}
         required_predictors = set(params.data_vars) - non_predictor_vars - exclude
-        available_predictors = set(predictors.keys())
+        available_predictors = set(predictors.keys()) - exclude

         if required_predictors - available_predictors:
             missing = sorted(required_predictors - available_predictors)
@@ -88,30 +98,47 @@
         if available_predictors - required_predictors:
             superfluous = sorted(available_predictors - required_predictors)
             superfluous = "', '".join(superfluous)
-            raise ValueError(f"Superfluous predictors: '{superfluous}'")
+            raise ValueError(
+                f"Superfluous predictors: '{superfluous}', either params "
+                "for this predictor are missing or you forgot to add it to 'exclude'."
+            )

         if "intercept" in exclude:
             prediction = xr.zeros_like(params.intercept)
         else:
             prediction = params.intercept

+        # if predictors is a DataTree, rename all data variables to "pred" to avoid conflicts
+        if isinstance(predictors, DataTree) and not predictors.equals(DataTree()):
+            predictors = map_over_subtree(
+                lambda ds: ds.rename({var: "pred" for var in ds.data_vars})
+            )(predictors)

Review comment (Member) on the equals(DataTree()) check:
    Suggested change:
    -        if isinstance(predictors, DataTree) and not predictors.equals(DataTree()):
    +        if isinstance(predictors, DataTree) and not predictors.is_empty:
    ?

Reply (Collaborator, author): No, because is_empty only checks the node it is called on, i.e. the root, which can be empty while other datasets exist in the tree. And we still need to check whether the whole tree is empty for the test without any predictors.
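To illustrate the reply (a sketch, assuming the datatree package semantics described above):

    import xarray as xr
    from datatree import DataTree

    dt = DataTree.from_dict({"tas": xr.Dataset({"tas": ("time", [1.0, 2.0])})})

    dt.is_empty             # True: the root node itself holds no data
    dt.equals(DataTree())   # False: a leaf of the tree does hold data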
Review comment (Member) on the map_over_subtree rename:
    Are we already sure there is only one da on the node? Or will this give a cryptic error message?
         for key in required_predictors:
-            prediction = prediction + predictors[key] * params[key]
+            prediction = (predictors[key] * params[key]).transpose() + prediction
+
+        prediction = (
+            _extract_single_dataarray_from_dt(prediction)
+            if isinstance(prediction, DataTree)
+            else prediction
+        )

Review comment (Member) on lines +120 to +124: Maybe don't make this a ternary operation if it does not nicely fit on one line.
    Suggested change:
    -        prediction = (
    -            _extract_single_dataarray_from_dt(prediction)
    -            if isinstance(prediction, DataTree)
    -            else prediction
    -        )
    +        if isinstance(prediction, DataTree):
    +            prediction = _extract_single_dataarray_from_dt(prediction)
-        return prediction
+        return prediction.rename("prediction")
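A usage sketch of the exclude logic, continuing the hypothetical lr and tas from the fit example above:

    full = lr.predict(predictors={"tas": tas})  # intercept + tas * coefficient
    anomaly = lr.predict(predictors={"tas": tas}, exclude={"intercept"})  # starts from 0
    full.name  # "prediction", per the rename above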

     def residuals(
         self,
-        predictors: Mapping[str, xr.DataArray],
+        predictors: dict[str, xr.DataArray] | DataTree | xr.Dataset,
         target: xr.DataArray,
-    ):
+    ) -> xr.DataArray:
"""
         Calculate the residuals of the fitted linear model

         Parameters
         ----------
-        predictors : dict of xr.DataArray
-            A dict of DataArray objects used as predictors. Must be 1D and contain `dim`.
+        predictors : dict of xr.DataArray | DataTree | xr.Dataset
+            A dict of DataArray objects used as predictors or a DataTree holding
+            each predictor in a leaf. Each predictor must be 1D and contain `dim`.
+            If predictors is an xr.Dataset, each predictor must be a DataArray.
         target : xr.DataArray
             Target DataArray. Must be 2D and contain `dim`.

@@ -126,7 +153,7 @@ def residuals(

         residuals = target - prediction

-        return residuals
+        return residuals.rename("residuals")
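And correspondingly for residuals, with the same hypothetical objects:

    resid = lr.residuals(predictors={"tas": tas}, target=target)
    resid.name  # "residuals"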

     @property
     def params(self):

@@ -182,12 +209,12 @@ def to_netcdf(self, filename, **kwargs):
         Additional keyword arguments passed to ``xr.Dataset.to_netcdf``
         """

-        params = self.params()
+        params = self.params
         params.to_netcdf(filename, **kwargs)
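Since params is a @property (see above), the call parentheses were the bug. A quick sketch with a hypothetical filename:

    lr.to_netcdf("lr_params.nc")
    params = xr.open_dataset("lr_params.nc")  # round-trips the params Dataset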


 def _fit_linear_regression_xr(
-    predictors: Mapping[str, xr.DataArray],
+    predictors: dict[str, xr.DataArray] | DataTree | xr.Dataset,
     target: xr.DataArray,
     dim: str,
     weights: xr.DataArray | None = None,
@@ -198,8 +225,10 @@

     Parameters
     ----------
-    predictors : dict of xr.DataArray
-        A dict of DataArray objects used as predictors. Must be 1D and contain `dim`.
+    predictors : dict of xr.DataArray | DataTree | xr.Dataset
+        A dict of DataArray objects used as predictors or a DataTree holding
+        each predictor in a leaf. Each predictor must be 1D and contain `dim`.
+        If predictors is an xr.Dataset, each predictor must be a DataArray.
     target : xr.DataArray
         Target DataArray. Must be 2D and contain `dim`.
     dim : str
@@ -217,8 +246,10 @@
     individual DataArray.
     """

-    if not isinstance(predictors, Mapping):
-        raise TypeError(f"predictors should be a dict, got {type(predictors)}.")
+    if not isinstance(predictors, dict | DataTree | xr.Dataset):
+        raise TypeError(
+            f"predictors should be a dict, DataTree or xr.Dataset, got {type(predictors)}."
+        )

     if ("weights" in predictors) or ("intercept" in predictors):
         raise ValueError(

@@ -229,14 +260,32 @@
         raise ValueError("dim cannot currently be 'predictor'.")

     for key, pred in predictors.items():
+        pred = (
+            _extract_single_dataarray_from_dt(pred)
+            if isinstance(pred, DataTree)
+            else pred
+        )
         _check_dataarray_form(pred, ndim=1, required_dims=dim, name=f"predictor: {key}")

Review comment (Member) on the ternary: de-inline?

-    predictors_concat = xr.concat(
-        tuple(predictors.values()),
-        dim="predictor",
-        join="exact",
-        coords="minimal",
-    )
+    if isinstance(predictors, dict | xr.Dataset):
+        predictors_concat = xr.concat(
+            tuple(predictors.values()),
+            dim="predictor",
+            join="exact",
+            coords="minimal",
+        )
+        predictors_concat = predictors_concat.assign_coords(
+            {"predictor": list(predictors.keys())}
+        )
+    elif isinstance(predictors, DataTree):
+        # rename all data variables to "pred" to avoid conflicts when concatenating
+        predictors = map_over_subtree(
+            lambda ds: ds.rename({var: "pred" for var in ds.data_vars})
+        )(predictors)
+        predictors_concat = collapse_datatree_into_dataset(predictors, dim="predictor")
+        predictors_concat = (
+            predictors_concat.to_array().isel(variable=0).drop_vars("variable")
+        )
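To see what the dict/Dataset branch builds, a standalone sketch using only plain xarray calls (hypothetical predictors):

    import numpy as np
    import xarray as xr

    preds = {
        "tas": xr.DataArray(np.arange(3.0), dims="time"),
        "tas2": xr.DataArray(np.arange(3.0) ** 2, dims="time"),
    }

    predictors_concat = xr.concat(
        tuple(preds.values()), dim="predictor", join="exact", coords="minimal"
    )
    predictors_concat = predictors_concat.assign_coords(
        {"predictor": list(preds.keys())}
    )
    # -> dims ("predictor", "time"), predictor labels ["tas", "tas2"];
    #    `keys` below takes its ordering from this coordinate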

     _check_dataarray_form(target, required_dims=dim, name="target")

@@ -267,7 +316,7 @@
     target = target.drop_vars(target[dim].coords)

     # split `out` into individual DataArrays
-    keys = ["intercept"] + list(predictors)
+    keys = ["intercept"] + list(predictors_concat.coords["predictor"].values)
     data_vars = {key: (target_dim, out[:, i]) for i, key in enumerate(keys)}
     out = xr.Dataset(data_vars, coords=target.coords)

10 changes: 7 additions & 3 deletions mesmer/testing.py
@@ -77,7 +77,7 @@ def assert_dict_allclose(first, second, first_name="left", second_name="right"):
         assert first_val == second_val, key


-def trend_data_1D(n_timesteps=30, intercept=0, slope=1, scale=1):
+def trend_data_1D(n_timesteps=30, intercept=0.0, slope=1.0, scale=1.0):

     time = np.arange(n_timesteps)

@@ -89,7 +89,9 @@ def trend_data_1D(n_timesteps=30, intercept=0, slope=1, scale=1):
     return xr.DataArray(data, dims=("time"), coords={"time": time}, name="data")


-def trend_data_2D(n_timesteps=30, n_lat=3, n_lon=2, intercept=0, slope=1, scale=1):
+def trend_data_2D(
+    n_timesteps=30, n_lat=3, n_lon=2, intercept=0.0, slope=1.0, scale=1.0
+):

     n_cells = n_lat * n_lon
     time = np.arange(n_timesteps)

@@ -109,7 +111,9 @@ def trend_data_2D(n_timesteps=30, n_lat=3, n_lon=2, intercept=0, slope=1, scale=1):
     return xr.DataArray(data, dims=("cells", "time"), coords=coords, name="data")


-def trend_data_3D(n_timesteps=30, n_lat=3, n_lon=2, intercept=0, slope=1, scale=1):
+def trend_data_3D(
+    n_timesteps=30, n_lat=3, n_lon=2, intercept=0.0, slope=1.0, scale=1.0
+):

     data = trend_data_2D(
         n_timesteps=n_timesteps,