From d87c2789fa95ea121708d553ac02309e984472eb Mon Sep 17 00:00:00 2001
From: Georges Dupret
Date: Tue, 25 May 2021 14:20:17 -0700
Subject: [PATCH 1/3] allow excluding features from the Boruta test

---
 .gitignore                |   1 +
 boruta/boruta_py.py       |  24 +++++---
 boruta/test/unit_tests.py | 117 +++++++++++++++++++++++++++++++-------
 3 files changed, 112 insertions(+), 30 deletions(-)
 create mode 100644 .gitignore

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..9f11b75
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+.idea/
diff --git a/boruta/boruta_py.py b/boruta/boruta_py.py
index 4eff5ac..9e084b0 100644
--- a/boruta/boruta_py.py
+++ b/boruta/boruta_py.py
@@ -9,6 +9,8 @@
 """
 from __future__ import print_function, division
+
+
 import numpy as np
 import scipy as sp
 from sklearn.utils import check_random_state, check_X_y
@@ -206,19 +208,24 @@ def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
         self.__version__ = '0.3'
         self._is_lightgbm = 'lightgbm' in str(type(self.estimator))

-    def fit(self, X, y):
+    def fit(self, X, y, C=None):
         """
-        Fits the Boruta feature selection with the provided estimator.
+        Fits the Boruta feature selection with the provided estimator. X contains the features whose
+        importance we want to test; C contains the features we already know are important or do not want to test.

         Parameters
         ----------
-        X : array-like, shape = [n_samples, n_features]
-            The training input samples.
+        X : array-like, shape = [n_samples, n_test_features]
+            The training input sample columns to be tested.

         y : array-like, shape = [n_samples]
             The target values.
+
+        C : array-like, shape = [n_samples, n_secured_features]
+            The training input sample columns that are not tested.
         """
+        self.C = C
         return self._fit(X, y)

     def transform(self, X, weak=False, return_df=False):
@@ -461,7 +468,7 @@ def _get_tree_num(self, n_feat):
                 "The estimator does not have a max_depth property, as a result "
                 " the number of trees to use cannot be estimated automatically."
             )
-        if depth == None:
+        if depth is None:
             depth = 10
         # how many times a feature should be considered on average
         f_repr = 100
@@ -471,13 +478,14 @@ def _get_tree_num(self, n_feat):
         return n_estimators

     def _get_imp(self, X, y):
+        data = X if self.C is None else np.hstack((X, self.C))
         try:
-            self.estimator.fit(X, y)
+            self.estimator.fit(data, y)
         except Exception as e:
            raise ValueError('Please check your X and y variable. The provided '
                             'estimator cannot be fitted to your data.\n' + str(e))
         try:
-            imp = self.estimator.feature_importances_
+            imp = self.estimator.feature_importances_[:X.shape[1]]
         except Exception:
             raise ValueError('Only methods with feature_importance_ attribute '
                              'are currently supported in BorutaPy.')
@@ -495,7 +503,7 @@ def _add_shadows_get_imps(self, X, y, dec_reg):
         # deep copy the matrix for the shadow matrix
         x_sha = np.copy(x_cur)
         # make sure there's at least 5 columns in the shadow matrix for
-        while (x_sha.shape[1] < 5):
+        while x_sha.shape[1] < 5:
             x_sha = np.hstack((x_sha, x_sha))
         # shuffle xSha
         x_sha = np.apply_along_axis(self._get_shuffle, 0, x_sha)
diff --git a/boruta/test/unit_tests.py b/boruta/test/unit_tests.py
index 5d5ce9f..4b7a8fe 100644
--- a/boruta/test/unit_tests.py
+++ b/boruta/test/unit_tests.py
@@ -3,6 +3,70 @@
 import pandas as pd
 from sklearn.ensemble import RandomForestClassifier
 import numpy as np
+import shap
+import xgboost as xgb
+
+xgboost_parameters = {
+    "alpha": 0.0,
+    "colsample_bylevel": 1.0,
+    "colsample_bytree": 1.0,
+    "eta": 0.3,
+    "eval_metric": ["error"],
+    "gamma": 0.0,
+    "lambda": 1.0,
+    "max_bin": 256,
+    "max_delta_step": 0,
+    "max_depth": 6,
+    "min_child_weight": 1,
+    "nthread": -1,
+    "objective": "binary:logistic",
+    "subsample": 1.0,
+    "tree_method": "auto"
+}
+
+
+class Learner:
+
+    def __init__(self, estimator):
+        self.estimator = estimator
+        self.explainer = None
+        self.feature_importances_ = None
+
+    def set_params(self, n_estimators=1000, random_state=None):
+        self.feature_importances_ = None
+        self.estimator.set_params(n_estimators=n_estimators, random_state=random_state)
+
+    def get_params(self):
+        return self.estimator.get_params()
+
+    def fit(self, X, y):
+        self.estimator.fit(X, y)
+        self.explainer = shap.TreeExplainer(self.estimator)
+        self.feature_importances_ = np.absolute(self.explainer.shap_values(X)).sum(axis=0)
+
+
+def create_data():
+    y = np.random.binomial(1, 0.5, 1000)
+    X = np.zeros((1000, 10))
+
+    z = y - np.random.binomial(1, 0.1, 1000) + np.random.binomial(1, 0.1, 1000)
+    z[z == -1] = 0
+    z[z == 2] = 1
+
+    # 5 relevant features
+    X[:, 0] = z
+    X[:, 1] = y * np.abs(np.random.normal(0, 1, 1000)) + np.random.normal(0, 0.1, 1000)
+    X[:, 2] = y + np.random.normal(0, 1, 1000)
+    X[:, 3] = y ** 2 + np.random.normal(0, 1, 1000)
+    X[:, 4] = np.sqrt(y) + np.random.binomial(2, 0.1, 1000)
+
+    # 5 irrelevant features
+    X[:, 5] = np.random.normal(0, 1, 1000)
+    X[:, 6] = np.random.poisson(1, 1000)
+    X[:, 7] = np.random.binomial(1, 0.3, 1000)
+    X[:, 8] = np.random.normal(0, 1, 1000)
+    X[:, 9] = np.random.poisson(1, 1000)
+    return X, y


 class BorutaTestCases(unittest.TestCase):
@@ -15,26 +79,7 @@ def test_get_tree_num(self):

     def test_if_boruta_extracts_relevant_features(self):
         np.random.seed(42)
-        y = np.random.binomial(1, 0.5, 1000)
-        X = np.zeros((1000, 10))
-
-        z = y - np.random.binomial(1, 0.1, 1000) + np.random.binomial(1, 0.1, 1000)
-        z[z == -1] = 0
-        z[z == 2] = 1
-
-        # 5 relevant features
-        X[:, 0] = z
-        X[:, 1] = y * np.abs(np.random.normal(0, 1, 1000)) + np.random.normal(0, 0.1, 1000)
-        X[:, 2] = y + np.random.normal(0, 1, 1000)
-        X[:, 3] = y ** 2 + np.random.normal(0, 1, 1000)
-        X[:, 4] = np.sqrt(y) + np.random.binomial(2, 0.1, 1000)
-
-        # 5 irrelevant features
-        X[:, 5] = np.random.normal(0, 1, 1000)
-        X[:, 6] = np.random.poisson(1, 1000)
-        X[:, 7] = np.random.binomial(1, 0.3, 1000)
-        X[:, 8] = np.random.normal(0, 1, 1000)
-        X[:, 9] = np.random.poisson(1, 1000)
+        X, y = create_data()

         rfc = RandomForestClassifier()
        bt = BorutaPy(rfc)
@@ -51,7 +96,35 @@ def test_if_boruta_extracts_relevant_features(self):
         # check it dataframe is returned when return_df=True
         self.assertIsInstance(bt.transform(X_df, return_df=True), pd.DataFrame)

-if __name__ == '__main__':
-    unittest.main()
+    def test_xgboost_version(self):
+        np.random.seed(42)
+        X, y = create_data()
+
+        bst = xgb.XGBRFRegressor(tree_method="hist", max_depth=5, n_estimators=10)
+        bt = BorutaPy(bst, n_estimators=bst.n_estimators)
+        bt.fit(X, y)
+
+        explainer = shap.TreeExplainer(bst)
+        shap_values = explainer.shap_values(X)
+        self.assertEqual(shap_values.shape, X.shape)
+
+    def test_xgboost_shapley(self):
+        np.random.seed(42)
+
+        # training data
+        X, y = create_data()
+        C = X[:, :3]  # features that are known to be important
+        T = X[:, 3:]  # features to test -- only the first two in T should turn out to be important
+
+        # Learner
+        bst = Learner(xgb.XGBRFRegressor(**xgboost_parameters, n_estimators=10))
+
+        # Boruta
+        bt = BorutaPy(bst, n_estimators=bst.get_params()['n_estimators'])
+        bt.fit(T, y, C)
+
+        self.assertListEqual(list(range(2)), list(np.where(bt.support_)[0]))
+
+
+if __name__ == '__main__':
+    unittest.main()

From fdaf0fe833570669000dac7973a02ac7e54ed550 Mon Sep 17 00:00:00 2001
From: Georges Dupret
Date: Tue, 25 May 2021 14:40:44 -0700
Subject: [PATCH 2/3] sync pazuzu.duckdns.org

---
 boruta/boruta_py.py       | 1 +
 boruta/test/unit_tests.py | 9 ++++-----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/boruta/boruta_py.py b/boruta/boruta_py.py
index 9e084b0..9504914 100644
--- a/boruta/boruta_py.py
+++ b/boruta/boruta_py.py
@@ -366,6 +366,7 @@ def _fit(self, X, y):
             # based on hit_reg we check if a feature is doing better than
             # expected by chance
             dec_reg = self._do_tests(dec_reg, hit_reg, _iter)
+            print(f"{dec_reg}")

             # print out confirmed features
             if self.verbose > 0 and _iter < self.max_iter:
diff --git a/boruta/test/unit_tests.py b/boruta/test/unit_tests.py
index 4b7a8fe..a4697a3 100644
--- a/boruta/test/unit_tests.py
+++ b/boruta/test/unit_tests.py
@@ -96,17 +96,16 @@ def test_if_boruta_extracts_relevant_features(self):
         # check it dataframe is returned when return_df=True
         self.assertIsInstance(bt.transform(X_df, return_df=True), pd.DataFrame)

-    def test_xgboost_version(self):
+    def test_xgboost_default(self):
         np.random.seed(42)
         X, y = create_data()

-        bst = xgb.XGBRFRegressor(tree_method="hist", max_depth=5, n_estimators=10)
+        bst = xgb.XGBRFRegressor(**xgboost_parameters)
         bt = BorutaPy(bst, n_estimators=bst.n_estimators)
         bt.fit(X, y)

-        explainer = shap.TreeExplainer(bst)
-        shap_values = explainer.shap_values(X)
-        self.assertEqual(shap_values.shape, X.shape)
+        # make sure that all and only the relevant features are returned
+        self.assertListEqual(list(range(5)), list(np.where(bt.support_)[0]))

     def test_xgboost_shapley(self):
         np.random.seed(42)

From 91a85afc3a31a86b433240d0716c90d36dbc75eb Mon Sep 17 00:00:00 2001
From: Georges Dupret
Date: Tue, 25 May 2021 17:24:54 -0700
Subject: [PATCH 3/3] - specify which features should be tested for importance
 - add a wrapper to xgboost to use Shap to measure importance

---
 boruta/boruta_py.py       | 17 +++++++---------
 boruta/test/unit_tests.py | 41 ++++++++++++++++++++++++---------------
 requirements.txt          |  8 ++++++++
 3 files changed, 40 insertions(+), 26 deletions(-)
 create mode 100644 requirements.txt

diff --git a/boruta/boruta_py.py b/boruta/boruta_py.py
index 9504914..1b65610 100644
--- a/boruta/boruta_py.py
+++ b/boruta/boruta_py.py
@@ -9,8 +9,6 @@
 """
 from __future__ import print_function, division
-
-
 import numpy as np
 import scipy as sp
 from sklearn.utils import check_random_state, check_X_y
@@ -323,12 +321,12 @@ def _fit(self, X, y):
         # 0 - default state = tentative in original code
         # 1 - accepted in original code
         # -1 - rejected in original code
-        dec_reg = np.zeros(n_feat, dtype=np.int)
+        dec_reg = np.zeros(n_feat, dtype=int)
         # counts how many times a given feature was more important than
         # the best of the shadow features
-        hit_reg = np.zeros(n_feat, dtype=np.int)
+        hit_reg = np.zeros(n_feat, dtype=int)
         # these record the history of the iterations
-        imp_history = np.zeros(n_feat, dtype=np.float)
+        imp_history = np.zeros(n_feat, dtype=float)
         sha_max_history = []

         # set n_estimators
@@ -366,7 +364,6 @@ def _fit(self, X, y):
             # based on hit_reg we check if a feature is doing better than
             # expected by chance
             dec_reg = self._do_tests(dec_reg, hit_reg, _iter)
-            print(f"{dec_reg}")

             # print out confirmed features
             if self.verbose > 0 and _iter < self.max_iter:
@@ -401,13 +398,13 @@ def _fit(self, X, y):

         # basic result variables
         self.n_features_ = confirmed.shape[0]
-        self.support_ = np.zeros(n_feat, dtype=np.bool)
+        self.support_ = np.zeros(n_feat, dtype=bool)
         self.support_[confirmed] = 1
-        self.support_weak_ = np.zeros(n_feat, dtype=np.bool)
+        self.support_weak_ = np.zeros(n_feat, dtype=bool)
         self.support_weak_[tentative] = 1

         # ranking, confirmed variables are rank 1
-        self.ranking_ = np.ones(n_feat, dtype=np.int)
+        self.ranking_ = np.ones(n_feat, dtype=int)
         # tentative variables are rank 2
         self.ranking_[tentative] = 2
         # selected = confirmed and tentative
@@ -433,7 +430,7 @@ def _fit(self, X, y):
             self.ranking_[not_selected] = ranks
         else:
             # all are selected, thus we set feature supports to True
-            self.support_ = np.ones(n_feat, dtype=np.bool)
+            self.support_ = np.ones(n_feat, dtype=bool)

         self.importance_history_ = imp_history

diff --git a/boruta/test/unit_tests.py b/boruta/test/unit_tests.py
index a4697a3..2a34608 100644
--- a/boruta/test/unit_tests.py
+++ b/boruta/test/unit_tests.py
@@ -27,22 +27,31 @@

 class Learner:

-    def __init__(self, estimator):
-        self.estimator = estimator
-        self.explainer = None
+    def __init__(self, params, nrounds=1000, verbose=False):
+        self.params = params
+        self.nrounds = nrounds
         self.feature_importances_ = None
-
-    def set_params(self, n_estimators=1000, random_state=None):
+        self.verbose = verbose
+
+    def set_params(self, n_estimators=None, random_state=None):
+        """
+        Used by boruta_py; for xgboost it only updates the number of boosting rounds.
+        :param n_estimators: the number of rounds, typically hard set in xgboost
+        :param random_state: ignored
+        """
         self.feature_importances_ = None
-        self.estimator.set_params(n_estimators=n_estimators, random_state=random_state)
+        if n_estimators:
+            self.nrounds = n_estimators

     def get_params(self):
-        return self.estimator.get_params()
+        return self.params

     def fit(self, X, y):
-        self.estimator.fit(X, y)
-        self.explainer = shap.TreeExplainer(self.estimator)
-        self.feature_importances_ = np.absolute(self.explainer.shap_values(X)).sum(axis=0)
+        dtrain = xgb.DMatrix(X, label=y)
+        eval_set = [(dtrain, 'test')]
+        model = xgb.train(self.params, dtrain, num_boost_round=self.nrounds, evals=eval_set, verbose_eval=self.verbose)
+        explainer = shap.TreeExplainer(model)
+        self.feature_importances_ = np.absolute(explainer.shap_values(X)).sum(axis=0)


 def create_data():
@@ -96,18 +105,18 @@ def test_if_boruta_extracts_relevant_features(self):
         # check it dataframe is returned when return_df=True
         self.assertIsInstance(bt.transform(X_df, return_df=True), pd.DataFrame)

-    def test_xgboost_default(self):
+    def test_xgboost_all_features(self):
         np.random.seed(42)
         X, y = create_data()

-        bst = xgb.XGBRFRegressor(**xgboost_parameters)
-        bt = BorutaPy(bst, n_estimators=bst.n_estimators)
+        bst = Learner(xgboost_parameters, nrounds=10)
+        bt = BorutaPy(bst, n_estimators=bst.nrounds, verbose=True)
         bt.fit(X, y)

         # make sure that all and only the relevant features are returned
         self.assertListEqual(list(range(5)), list(np.where(bt.support_)[0]))

-    def test_xgboost_shapley(self):
+    def test_xgboost_some_features(self):
         np.random.seed(42)

         # training data
         X, y = create_data()
@@ -116,10 +125,10 @@ def test_xgboost_shapley(self):
         C = X[:, :3]  # features that are known to be important
         T = X[:, 3:]  # features to test -- only the first two in T should turn out to be important

         # Learner
-        bst = Learner(xgb.XGBRFRegressor(**xgboost_parameters, n_estimators=10))
+        bst = Learner(xgboost_parameters, nrounds=25)

         # Boruta
-        bt = BorutaPy(bst, n_estimators=bst.get_params()['n_estimators'])
+        bt = BorutaPy(bst, n_estimators=bst.nrounds, max_iter=10, verbose=True)
         bt.fit(T, y, C)

         self.assertListEqual(list(range(2)), list(np.where(bt.support_)[0]))
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..0d3c61d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+pandas~=1.2.3
+numpy~=1.20.1
+shap~=0.39.0
+xgboost~=1.3.3
+sklearn~=0.0
+scikit-learn~=0.24.1
+scipy~=1.6.2
+setuptools~=53.1.0
\ No newline at end of file
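Reviewer note: the sketch below illustrates how the API introduced by this series is meant to be used once the three patches are applied. The Learner class is condensed from the wrapper defined in boruta/test/unit_tests.py above; the synthetic data, the trimmed parameter dict, and the split between trusted columns C and tested columns T are illustrative assumptions, not part of the patches.

    import numpy as np
    import shap
    import xgboost as xgb
    from boruta import BorutaPy


    class Learner:
        """Minimal SHAP-importance wrapper, condensed from the test file above."""

        def __init__(self, params, nrounds=100):
            self.params = params
            self.nrounds = nrounds
            self.feature_importances_ = None

        def set_params(self, n_estimators=None, random_state=None):
            # BorutaPy calls this; only the number of boosting rounds matters here
            if n_estimators:
                self.nrounds = n_estimators

        def get_params(self):
            return self.params

        def fit(self, X, y):
            dtrain = xgb.DMatrix(X, label=y)
            model = xgb.train(self.params, dtrain, num_boost_round=self.nrounds)
            # summed |SHAP| per column is what BorutaPy reads back as the importance
            self.feature_importances_ = np.absolute(
                shap.TreeExplainer(model).shap_values(X)).sum(axis=0)


    # illustrative data: columns 0-3 drive the target, the rest are noise
    rng = np.random.RandomState(42)
    X = rng.normal(size=(1000, 10))
    y = (X[:, :4].sum(axis=1) + rng.normal(scale=0.5, size=1000) > 0).astype(int)

    C = X[:, :2]   # columns we already trust and do not want tested
    T = X[:, 2:]   # columns whose relevance Boruta should decide

    bst = Learner({"objective": "binary:logistic", "eta": 0.3, "max_depth": 6}, nrounds=25)
    bt = BorutaPy(bst, n_estimators=bst.nrounds, max_iter=10)
    bt.fit(T, y, C)                    # C is fed to every model fit but never tested
    print(np.where(bt.support_)[0])    # indices into the columns of T that were confirmed

Because _get_imp now concatenates C to whatever matrix it fits (real or shadow) and then keeps only the first X.shape[1] importances, the trusted columns influence every model but never compete against the shadow features, and support_ and ranking_ are reported only for the tested columns.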