Add data tree and dataset formats to linear regression #566

Open: wants to merge 7 commits into base branch main

veni-vidi-vici-dormivi (Collaborator) commented Nov 21, 2024:

This PR implements handling of DataTree and xr.Dataset objects as predictor formats.

veni-vidi-vici-dormivi (Collaborator, Author):

Includes #361

mathause (Member) left a comment:

Looks good - some suggestions.

mesmer/stats/_linear_regression.py (outdated, resolved)
mesmer/stats/_linear_regression.py (resolved)

if "intercept" in exclude:
    prediction = xr.zeros_like(params.intercept)
else:
    prediction = params.intercept

# if predictors is a DataTree, rename all data variables to "pred" to avoid conflicts
if isinstance(predictors, DataTree) and not predictors.equals(DataTree()):
mathause (Member):

Suggested change
- if isinstance(predictors, DataTree) and not predictors.equals(DataTree()):
+ if isinstance(predictors, DataTree) and not predictors.is_empty:

?

veni-vidi-vici-dormivi (Collaborator, Author):

No, because is_empty only checks the node itself, here the root, which can be empty while other nodes in the datatree still hold datasets. And we still need to check whether the whole tree is empty for the test without any predictors.
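
The distinction drawn in this reply can be illustrated with a toy sketch in plain Python. The `Node` class, its `is_empty` property, its `subtree` method, and the `tree_has_no_data` helper are hypothetical stand-ins, not the DataTree API; they only mimic the behavior described here, where a node-level emptiness check does not look at children.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Toy stand-in for a tree node (an assumption, not the real DataTree)."""

    data: dict = field(default_factory=dict)
    children: dict = field(default_factory=dict)

    @property
    def is_empty(self) -> bool:
        # Looks only at this node's own data, never at its children.
        return not self.data

    def subtree(self):
        # Yield this node first, then every descendant, depth first.
        yield self
        for child in self.children.values():
            yield from child.subtree()


def tree_has_no_data(root: Node) -> bool:
    """Return True only if every node in the whole tree is empty."""
    return all(node.is_empty for node in root.subtree())


# The root holds no data itself, but one child does.
tree = Node(children={"tas": Node(data={"tas": [0, 1, 2]})})
```

In this sketch `tree.is_empty` is true even though a child holds data, so a node-level check alone would wrongly treat the predictors as missing; a whole-tree check such as `tree_has_no_data` (or the comparison against an empty tree used in the diff) avoids that.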

# if predictors is a DataTree, rename all data variables to "pred" to avoid conflicts
if isinstance(predictors, DataTree) and not predictors.equals(DataTree()):
    predictors = map_over_subtree(
        lambda ds: ds.rename({var: "pred" for var in ds.data_vars})
mathause (Member):

Are we already sure there is only one da on the node? Or will this give a cryptic error message?
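
One way to avoid a cryptic error would be an explicit guard before renaming. The following is only a sketch under assumptions: `rename_single_var` is a hypothetical helper, and a plain dict stands in for the data variables of one node, so it does not show how the check would be wired into the code under review.

```python
def rename_single_var(data_vars: dict, new_name: str = "pred") -> dict:
    """Rename the only entry of ``data_vars`` to ``new_name``.

    ``data_vars`` is a plain dict standing in for the data variables of a
    single node (an assumption made for this illustration).
    """
    if len(data_vars) != 1:
        names = sorted(data_vars)
        raise ValueError(
            "Expected exactly one data variable per node, "
            f"found {len(data_vars)}: {names}"
        )
    # Unpack the single value; the length check above guarantees this works.
    (value,) = data_vars.values()
    return {new_name: value}
```

With such a guard, a node carrying two variables fails with a message that names the offending variables instead of a rename conflict raised further down.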

Comment on lines +120 to +124
prediction = (
    _extract_single_dataarray_from_dt(prediction)
    if isinstance(prediction, DataTree)
    else prediction
)
mathause (Member):

Maybe don't make this a ternary operation if it does not fit nicely on one line.

Suggested change
- prediction = (
-     _extract_single_dataarray_from_dt(prediction)
-     if isinstance(prediction, DataTree)
-     else prediction
- )
+ if isinstance(prediction, DataTree):
+     prediction = _extract_single_dataarray_from_dt(prediction)

@@ -229,14 +260,32 @@ def _fit_linear_regression_xr(
raise ValueError("dim cannot currently be 'predictor'.")

for key, pred in predictors.items():
    pred = (
mathause (Member):

de-inline?

Comment on lines +14 to +15
def to_dict(data_dict):
    return data_dict
mathause (Member):

Is that doing anything? Maybe add a comment, e.g. "# no-op so all three options have a conversion function" (or so).

return DataTree.from_dict(data_dict)


def to_xr_dataset(data_dict):
mathause (Member):

Suggested change
- def to_xr_dataset(data_dict):
+ def to_dataset(data_dict):

?

@@ -79,24 +93,74 @@ def test_lr_params():


@pytest.mark.parametrize("as_2D", [True, False])
def test_lr_predict(as_2D):
@pytest.mark.parametrize("data_structure", [to_dict, to_datatree, to_xr_dataset])
mathause (Member):

Good approach!

mathause (Member):

Actually, it's less clear in the code with data_structure. Maybe rename it to to_data_type? Or maybe data_type is too close to dtype; could use data_cls?

Alternatively you could write

def convert_to(dct, data_type):
    if data_type == "dict":
        return dct
    ...

# and 
@pytest.mark.parametrize("data_type", ["dict", "datatree", "dataset"])
def test_(...):
    ...

    pred = convert_to({"tas": tas}, data_type)

)
lr.params = params if as_2D else params.squeeze()

tas = xr.DataArray([0, 1, 2], dims="time").rename("tas")
mathause (Member):

Suggested change
- tas = xr.DataArray([0, 1, 2], dims="time").rename("tas")
+ tas = xr.DataArray([0, 1, 2], dims="time", name="tas")

codecov bot commented Nov 21, 2024

Codecov Report

Attention: Patch coverage is 96.42857% with 1 line in your changes missing coverage. Please review.

Project coverage is 77.83%. Comparing base (b80d031) to head (05eb6cc).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| mesmer/stats/_linear_regression.py | 95.65% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #566      +/-   ##
==========================================
+ Coverage   77.75%   77.83%   +0.08%     
==========================================
  Files          49       49              
  Lines        2967     2978      +11     
==========================================
+ Hits         2307     2318      +11     
  Misses        660      660              
| Flag | Coverage Δ |
|---|---|
| unittests | 77.83% <96.42%> (+0.08%) ⬆️ |

Development

Successfully merging this pull request may close these issues.

stack predictors and targets for xarray objects
2 participants