Add data tree and dataset formats to linear regression #566
Includes #361
Looks good - some suggestions.
if "intercept" in exclude:
    prediction = xr.zeros_like(params.intercept)
else:
    prediction = params.intercept

# if predictors is a DataTree, rename all data variables to "pred" to avoid conflicts
if isinstance(predictors, DataTree) and not predictors.equals(DataTree()):
- if isinstance(predictors, DataTree) and not predictors.equals(DataTree()):
+ if isinstance(predictors, DataTree) and not predictors.is_empty:
?
No, because is_empty only checks the node it is called on, i.e. the root, which can be empty while other nodes in the DataTree still hold datasets. And we still need to check whether the whole tree is empty for the test without any predictors.
# if predictors is a DataTree, rename all data variables to "pred" to avoid conflicts
if isinstance(predictors, DataTree) and not predictors.equals(DataTree()):
    predictors = map_over_subtree(
        lambda ds: ds.rename({var: "pred" for var in ds.data_vars})
Are we already sure there is only one da on the node? Or will this give a cryptic error message?
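One way to avoid a cryptic failure would be an explicit check before renaming. The following is only a sketch: `rename_single_var` is a hypothetical helper, not something from the PR, and passing nodes without data variables through unchanged is an assumption about how empty nodes should be treated.

```python
import xarray as xr


def rename_single_var(ds: xr.Dataset, new_name: str = "pred") -> xr.Dataset:
    """Rename the only data variable of ``ds`` (hypothetical helper, not from the PR)."""
    n_vars = len(ds.data_vars)
    if n_vars == 0:
        # assumption: nodes without data (e.g. an empty root) pass through unchanged
        return ds
    if n_vars > 1:
        raise ValueError(
            f"Expected exactly one data variable per node, found {n_vars}: "
            f"{list(ds.data_vars)}"
        )
    (old_name,) = ds.data_vars
    return ds.rename({old_name: new_name})


ds = xr.Dataset({"tas": ("time", [0.0, 1.0, 2.0])})
renamed = rename_single_var(ds)
print(list(renamed.data_vars))
```

A named function like this could replace the lambda in the mapping call, so a node with several data variables fails with a message that says what went wrong.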
prediction = (
    _extract_single_dataarray_from_dt(prediction)
    if isinstance(prediction, DataTree)
    else prediction
)
Maybe don't make this a ternary operation if it does not nicely fit on one line
- prediction = (
-     _extract_single_dataarray_from_dt(prediction)
-     if isinstance(prediction, DataTree)
-     else prediction
- )
+ if isinstance(prediction, DataTree):
+     prediction = _extract_single_dataarray_from_dt(prediction)
@@ -229,14 +260,32 @@ def _fit_linear_regression_xr(
        raise ValueError("dim cannot currently be 'predictor'.")

    for key, pred in predictors.items():
        pred = (
de-inline?
def to_dict(data_dict):
    return data_dict
Is that doing anything? Maybe add a comment:
# no-op so all three options have a conversion function
(or so)
    return DataTree.from_dict(data_dict)


def to_xr_dataset(data_dict):
- def to_xr_dataset(data_dict):
+ def to_dataset(data_dict):
?
@@ -79,24 +93,74 @@ def test_lr_params():


@pytest.mark.parametrize("as_2D", [True, False])
def test_lr_predict(as_2D):
@pytest.mark.parametrize("data_structure", [to_dict, to_datatree, to_xr_dataset])
Good approach!
Actually it's less clear in the code with the data_structure. Maybe rename to to_data_type? Or maybe data_type is too close to dtype - could use data_cls?
Alternatively you could write
def convert_to(dct, data_type):
    if data_type == "dict":
        return dct
    ...
# and
@pytest.mark.parametrize("data_type", ["dict", "datatree", "dataset"])
def test_(...):
    ...
    pred = convert_to({"tas": tas}, data_type)
    )
    lr.params = params if as_2D else params.squeeze()

    tas = xr.DataArray([0, 1, 2], dims="time").rename("tas")
- tas = xr.DataArray([0, 1, 2], dims="time").rename("tas")
+ tas = xr.DataArray([0, 1, 2], dims="time", name="tas")
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #566      +/-   ##
==========================================
+ Coverage   77.75%   77.83%   +0.08%
==========================================
  Files          49       49
  Lines        2967     2978      +11
==========================================
+ Hits         2307     2318      +11
  Misses        660      660
Co-authored-by: Mathias Hauser <[email protected]>
I implement handling of DataTrees and xr.Datasets as predictor formats.
CHANGELOG.rst