lin reg: only warn on superfluous predictors #361
base: main
Conversation
Codecov Report: all modified and coverable lines are covered by tests ✅

Coverage diff (main vs. #361):

|          |   main |   #361 |    |
|----------|--------|--------|----|
| Coverage | 87.94% | 87.95% |    |
| Files    |     40 |     40 |    |
| Lines    |   1751 |   1752 | +1 |
| Hits     |   1540 |   1541 | +1 |
| Misses   |    211 |    211 |    |
I would like to suggest taking into account:

```python
def predict(
    self,
    predictors: dict[str, xr.DataArray] | DataTree | xr.Dataset,
    exclude=None,
) -> xr.DataArray:
    """
    Predict using the linear model.

    Parameters
    ----------
    predictors : dict of xr.DataArray | DataTree | xr.Dataset
        A dict of ``DataArray`` objects used as predictors or a ``DataTree``,
        holding each predictor in a leaf. Each predictor must be 1D and contain
        ``dim``. If predictors is a ``xr.Dataset``, it must have each predictor
        as a ``DataArray``.
    exclude : str or set of str, default: None
        Set of variables to exclude from the prediction. May include
        ``"intercept"`` to initialize the prediction with 0.

    Returns
    -------
    prediction : xr.DataArray
        The predicted values.
    """
    params = self.params
    exclude = _to_set(exclude)

    non_predictor_vars = {"intercept", "weights", "fit_intercept"}
    required_predictors = set(params.data_vars) - non_predictor_vars - exclude
    available_predictors = set(predictors.keys()) - exclude

    if required_predictors - available_predictors:
        missing = sorted(required_predictors - available_predictors)
        missing = "', '".join(missing)
        raise ValueError(f"Missing predictors: '{missing}'")

    if available_predictors - required_predictors:
        superfluous = sorted(available_predictors - required_predictors)
        superfluous = "', '".join(superfluous)
        # NOTE: pass a single message string; a comma between the strings
        # would hand ValueError two arguments and garble the message
        raise ValueError(
            f"Superfluous predictors: '{superfluous}', either the params "
            "for these predictors are missing or you forgot to add them "
            "to 'exclude'."
        )

    if "intercept" in exclude:
        prediction = xr.zeros_like(params.intercept)
    else:
        prediction = params.intercept

    # if predictors is a DataTree, rename all data variables to "pred"
    # to avoid conflicts
    if isinstance(predictors, DataTree) and not predictors.equals(DataTree()):
        predictors = map_over_subtree(
            lambda ds: ds.rename({var: "pred" for var in ds.data_vars})
        )(predictors)

    for key in required_predictors:
        prediction = (predictors[key] * params[key]).transpose() + prediction

    prediction = (
        _extract_single_dataarray_from_dt(prediction)
        if isinstance(prediction, DataTree)
        else prediction
    )

    return prediction.rename("prediction")
```
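The set arithmetic that classifies predictors can be sketched in isolation with plain sets (a minimal sketch; `check_predictors` and the variable names `"tas"`, `"hfds"`, `"extra"` are illustrative, not part of the actual API):

```python
# Minimal sketch of the predictor-validation logic from the predict
# method above, decoupled from xarray: classify passed predictor names
# as missing or superfluous relative to the fitted params.
def check_predictors(param_vars, predictor_vars, exclude=frozenset()):
    # params that are not predictors (same names as in the snippet above)
    non_predictor_vars = {"intercept", "weights", "fit_intercept"}
    required = set(param_vars) - non_predictor_vars - set(exclude)
    available = set(predictor_vars) - set(exclude)
    missing = sorted(required - available)
    superfluous = sorted(available - required)
    return missing, superfluous

# e.g. params were fit with "tas" and "hfds", but the caller passes
# "tas" plus an unknown "extra" variable
missing, superfluous = check_predictors(
    ["intercept", "tas", "hfds"], ["tas", "extra"]
)
```

Here `missing` comes out as `["hfds"]` and `superfluous` as `["extra"]`, which is exactly what the two `ValueError` branches report.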
isort . && black . && flake8
CHANGELOG.rst
Just had a use case for this. Not sure if a good idea or not.
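The change the PR title describes could look like this (a sketch only, assuming the final code swaps the `ValueError` for a `UserWarning`; `warn_on_superfluous` is a hypothetical helper, not the merged implementation):

```python
import warnings


def warn_on_superfluous(superfluous):
    # Sketch: warn instead of raising when extra predictors are passed,
    # so prediction can proceed with the known predictors only.
    if superfluous:
        joined = "', '".join(sorted(superfluous))
        warnings.warn(
            f"Superfluous predictors: '{joined}', either the params for "
            "these predictors are missing or you forgot to add them to "
            "'exclude'."
        )


# demonstrate that a warning (not an exception) is emitted
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_on_superfluous({"extra"})
```

After this call, `caught` holds one `UserWarning` mentioning the superfluous name, and no exception was raised.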