lin reg: only warn on superfluous predictors #361

Draft
wants to merge 1 commit into main
Conversation

mathause
Member

  • Closes #xxx
  • Tests added
  • Passes isort . && black . && flake8
  • Fully documented, including CHANGELOG.rst

Just had a use case for this. Not sure if a good idea or not.
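The PR title proposes downgrading the superfluous-predictor error to a warning. A minimal sketch of that behavior, kept separate from the actual diff (the function name `check_predictors` and the set arguments are illustrative, not mesmer's API):

```python
import warnings


def check_predictors(required_predictors: set, available_predictors: set) -> None:
    # Missing predictors still abort: we cannot predict without their params.
    missing = sorted(required_predictors - available_predictors)
    if missing:
        raise ValueError(f"Missing predictors: {missing}")

    # Extra predictors are harmless (they are simply not used), so only warn.
    superfluous = sorted(available_predictors - required_predictors)
    if superfluous:
        warnings.warn(f"Superfluous predictors: {superfluous}")


check_predictors({"tas", "tas2"}, {"tas", "tas2", "hfds"})
# warns: Superfluous predictors: ['hfds']
```

The asymmetry is the point of the PR: missing predictors make the prediction impossible, while extra ones only suggest a possible user mistake.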


codecov bot commented Dec 19, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (c0069eb) 87.94% compared to head (3a39bf5) 87.95%.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #361   +/-   ##
=======================================
  Coverage   87.94%   87.95%           
=======================================
  Files          40       40           
  Lines        1751     1752    +1     
=======================================
+ Hits         1540     1541    +1     
  Misses        211      211           
Flag      | Coverage Δ
unittests | 87.95% <100.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.


@veni-vidi-vici-dormivi
Collaborator

veni-vidi-vici-dormivi commented Nov 8, 2024

I would like to suggest taking ``exclude`` into account for the available predictors as well, so that you can exclude predictor variables not only from the params but also from the predictors ``Dataset``. You might not want to change anything about the original predictor set, yet still use only part of it. This makes the most sense when a ``DataTree`` or ``Dataset`` is the data structure, because it can be less work to exclude a predictor from a ``Dataset`` than to build two separate ``Dataset``s.

def predict(
        self,
        predictors: dict[str, xr.DataArray] | DataTree | xr.Dataset,
        exclude=None,
    ) -> xr.DataArray:
        """
        Predict using the linear model.

        Parameters
        ----------
        predictors : dict of xr.DataArray | DataTree | xr.Dataset
            A dict of ``DataArray`` objects used as predictors or a ``DataTree``, holding each
            predictor in a leaf. Each predictor must be 1D and contain ``dim``. If predictors
            is an ``xr.Dataset``, it must have each predictor as a ``DataArray``.
        exclude : str or set of str, default: None
            Set of variables to exclude in the prediction. May include ``"intercept"``
            to initialize the prediction with 0.

        Returns
        -------
        prediction : xr.DataArray
            Returns predicted values.
        """

        params = self.params

        exclude = _to_set(exclude)

        non_predictor_vars = {"intercept", "weights", "fit_intercept"}
        required_predictors = set(params.data_vars) - non_predictor_vars - exclude
        available_predictors = set(predictors.keys()) - exclude

        if required_predictors - available_predictors:
            missing = sorted(required_predictors - available_predictors)
            missing = "', '".join(missing)
            raise ValueError(f"Missing predictors: '{missing}'")

        if available_predictors - required_predictors:
            superfluous = sorted(available_predictors - required_predictors)
            superfluous = "', '".join(superfluous)
            raise ValueError(
                f"Superfluous predictors: '{superfluous}', either params "
                "for this predictor are missing or you forgot to add it to 'exclude'."
            )

        if "intercept" in exclude:
            prediction = xr.zeros_like(params.intercept)
        else:
            prediction = params.intercept

        # if predictors is a DataTree, rename all data variables to "pred" to avoid conflicts
        if isinstance(predictors, DataTree) and not predictors.equals(DataTree()):
            predictors = map_over_subtree(
                lambda ds: ds.rename({var: "pred" for var in ds.data_vars})
            )(predictors)

        for key in required_predictors:
            prediction = (predictors[key] * params[key]).transpose() + prediction

        prediction = (
            _extract_single_dataarray_from_dt(prediction)
            if isinstance(prediction, DataTree)
            else prediction
        )

        return prediction.rename("prediction")
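The effect of subtracting ``exclude`` on both sides of the validation can be seen in isolation. In this sketch the variable names and sets are hypothetical, but the set arithmetic mirrors the snippet above: a predictor present in the ``Dataset`` but listed in ``exclude`` no longer counts as superfluous, and its params are no longer required.

```python
# Hypothetical params and predictor variables, not taken from the PR.
non_predictor_vars = {"intercept", "weights", "fit_intercept"}
params_vars = {"intercept", "tas", "hfds", "fit_intercept"}
predictor_vars = {"tas", "hfds", "treff"}

# Excluding "hfds" and "treff" removes them from both sides of the check.
exclude = {"hfds", "treff"}

required_predictors = params_vars - non_predictor_vars - exclude
available_predictors = predictor_vars - exclude

assert required_predictors == {"tas"}
assert available_predictors == {"tas"}

# Neither the "missing" nor the "superfluous" branch triggers.
assert not (required_predictors - available_predictors)
assert not (available_predictors - required_predictors)
```

Without ``exclude`` applied to ``available_predictors``, the same input would leave ``treff`` in the available set and trip the superfluous-predictor check, which is exactly the situation the comment above describes.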
