diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..9f11b75
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+.idea/
diff --git a/boruta/boruta_py.py b/boruta/boruta_py.py
index 4eff5ac..1b65610 100644
--- a/boruta/boruta_py.py
+++ b/boruta/boruta_py.py
@@ -206,19 +206,24 @@ def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
         self.__version__ = '0.3'
         self._is_lightgbm = 'lightgbm' in str(type(self.estimator))

-    def fit(self, X, y):
+    def fit(self, X, y, C=None):
         """
-        Fits the Boruta feature selection with the provided estimator.
+        Fits the Boruta feature selection with the provided estimator. X contains the features whose importance we
+        want to test; C contains the features we already know are important, or do not want to test.

         Parameters
         ----------
-        X : array-like, shape = [n_samples, n_features]
-            The training input samples.
+        X : array-like, shape = [n_samples, n_test_features]
+            The training input sample columns to be tested.

         y : array-like, shape = [n_samples]
             The target values.
+
+        C : array-like, shape = [n_samples, n_secured_features]
+            The training input sample columns that are not tested but are always included when fitting the estimator.
         """
+        self.C = C
         return self._fit(X, y)

     def transform(self, X, weak=False, return_df=False):
@@ -316,12 +321,12 @@ def _fit(self, X, y):
         # 0 - default state = tentative in original code
         # 1 - accepted in original code
         # -1 - rejected in original code
-        dec_reg = np.zeros(n_feat, dtype=np.int)
+        dec_reg = np.zeros(n_feat, dtype=int)
         # counts how many times a given feature was more important than
         # the best of the shadow features
-        hit_reg = np.zeros(n_feat, dtype=np.int)
+        hit_reg = np.zeros(n_feat, dtype=int)
         # these record the history of the iterations
-        imp_history = np.zeros(n_feat, dtype=np.float)
+        imp_history = np.zeros(n_feat, dtype=float)
         sha_max_history = []

         # set n_estimators
@@ -393,13 +398,13 @@ def _fit(self, X, y):

         # basic result variables
         self.n_features_ = confirmed.shape[0]
-        self.support_ = np.zeros(n_feat, dtype=np.bool)
+        self.support_ = np.zeros(n_feat, dtype=bool)
         self.support_[confirmed] = 1
-        self.support_weak_ = np.zeros(n_feat, dtype=np.bool)
+        self.support_weak_ = np.zeros(n_feat, dtype=bool)
         self.support_weak_[tentative] = 1

         # ranking, confirmed variables are rank 1
-        self.ranking_ = np.ones(n_feat, dtype=np.int)
+        self.ranking_ = np.ones(n_feat, dtype=int)
         # tentative variables are rank 2
         self.ranking_[tentative] = 2

         # selected = confirmed and tentative
@@ -425,7 +430,7 @@ def _fit(self, X, y):
             self.ranking_[not_selected] = ranks
         else:
             # all are selected, thus we set feature supports to True
-            self.support_ = np.ones(n_feat, dtype=np.bool)
+            self.support_ = np.ones(n_feat, dtype=bool)

         self.importance_history_ = imp_history

@@ -461,7 +466,7 @@ def _get_tree_num(self, n_feat):
                 "The estimator does not have a max_depth property, as a result "
                 " the number of trees to use cannot be estimated automatically."
             )
-        if depth == None:
+        if depth is None:
             depth = 10
         # how many times a feature should be considered on average
         f_repr = 100
@@ -471,13 +476,14 @@ def _get_tree_num(self, n_feat):
         return n_estimators

     def _get_imp(self, X, y):
+        data = X if self.C is None else np.hstack((X, self.C))
         try:
-            self.estimator.fit(X, y)
+            self.estimator.fit(data, y)
         except Exception as e:
             raise ValueError('Please check your X and y variable. The provided '
                              'estimator cannot be fitted to your data.\n' + str(e))
         try:
-            imp = self.estimator.feature_importances_
+            imp = self.estimator.feature_importances_[:X.shape[1]]
         except Exception:
             raise ValueError('Only methods with feature_importance_ attribute '
                              'are currently supported in BorutaPy.')
@@ -495,7 +501,7 @@ def _add_shadows_get_imps(self, X, y, dec_reg):
         # deep copy the matrix for the shadow matrix
         x_sha = np.copy(x_cur)
         # make sure there's at least 5 columns in the shadow matrix for
-        while (x_sha.shape[1] < 5):
+        while x_sha.shape[1] < 5:
             x_sha = np.hstack((x_sha, x_sha))
         # shuffle xSha
         x_sha = np.apply_along_axis(self._get_shuffle, 0, x_sha)
diff --git a/boruta/test/unit_tests.py b/boruta/test/unit_tests.py
index 5d5ce9f..2a34608 100644
--- a/boruta/test/unit_tests.py
+++ b/boruta/test/unit_tests.py
@@ -3,6 +3,79 @@
 import pandas as pd
 from sklearn.ensemble import RandomForestClassifier
 import numpy as np
+import shap
+import xgboost as xgb
+
+xgboost_parameters = {
+    "alpha": 0.0,
+    "colsample_bylevel": 1.0,
+    "colsample_bytree": 1.0,
+    "eta": 0.3,
+    "eval_metric": ["error"],
+    "gamma": 0.0,
+    "lambda": 1.0,
+    "max_bin": 256,
+    "max_delta_step": 0,
+    "max_depth": 6,
+    "min_child_weight": 1,
+    "nthread": -1,
+    "objective": "binary:logistic",
+    "subsample": 1.0,
+    "tree_method": "auto"
+}
+
+
+class Learner:
+
+    def __init__(self, params, nrounds=1000, verbose=False):
+        self.params = params
+        self.nrounds = nrounds
+        self.feature_importances_ = None
+        self.verbose = verbose
+
+    def set_params(self, n_estimators=None, random_state=None):
+        """
+        Called by boruta_py, but largely a no-op for xgboost, where the number of rounds is set explicitly.
+        :param n_estimators: the number of boosting rounds, typically hard-set in xgboost
+        :param random_state: ignored
+        """
+        self.feature_importances_ = None
+        if n_estimators:
+            self.nrounds = n_estimators
+
+    def get_params(self):
+        return self.params
+
+    def fit(self, X, y):
+        dtrain = xgb.DMatrix(X, label=y)
+        eval_set = [(dtrain, 'test')]
+        model = xgb.train(self.params, dtrain, num_boost_round=self.nrounds, evals=eval_set, verbose_eval=self.verbose)
+        explainer = shap.TreeExplainer(model)
+        self.feature_importances_ = np.absolute(explainer.shap_values(X)).sum(axis=0)
+
+
+def create_data():
+    y = np.random.binomial(1, 0.5, 1000)
+    X = np.zeros((1000, 10))
+
+    z = y - np.random.binomial(1, 0.1, 1000) + np.random.binomial(1, 0.1, 1000)
+    z[z == -1] = 0
+    z[z == 2] = 1
+
+    # 5 relevant features
+    X[:, 0] = z
+    X[:, 1] = y * np.abs(np.random.normal(0, 1, 1000)) + np.random.normal(0, 0.1, 1000)
+    X[:, 2] = y + np.random.normal(0, 1, 1000)
+    X[:, 3] = y ** 2 + np.random.normal(0, 1, 1000)
+    X[:, 4] = np.sqrt(y) + np.random.binomial(2, 0.1, 1000)
+
+    # 5 irrelevant features
+    X[:, 5] = np.random.normal(0, 1, 1000)
+    X[:, 6] = np.random.poisson(1, 1000)
+    X[:, 7] = np.random.binomial(1, 0.3, 1000)
+    X[:, 8] = np.random.normal(0, 1, 1000)
+    X[:, 9] = np.random.poisson(1, 1000)
+    return X, y


 class BorutaTestCases(unittest.TestCase):
@@ -15,26 +88,7 @@ def test_get_tree_num(self):

     def test_if_boruta_extracts_relevant_features(self):
         np.random.seed(42)
-        y = np.random.binomial(1, 0.5, 1000)
-        X = np.zeros((1000, 10))
-
-        z = y - np.random.binomial(1, 0.1, 1000) + np.random.binomial(1, 0.1, 1000)
-        z[z == -1] = 0
-        z[z == 2] = 1
-
-        # 5 relevant features
-        X[:, 0] = z
-        X[:, 1] = y * np.abs(np.random.normal(0, 1, 1000)) + np.random.normal(0, 0.1, 1000)
-        X[:, 2] = y + np.random.normal(0, 1, 1000)
-        X[:, 3] = y ** 2 + np.random.normal(0, 1, 1000)
-        X[:, 4] = np.sqrt(y) + np.random.binomial(2, 0.1, 1000)
-
-        # 5 irrelevant features
-        X[:, 5] = np.random.normal(0, 1, 1000)
-        X[:, 6] = np.random.poisson(1, 1000)
-        X[:, 7] = np.random.binomial(1, 0.3, 1000)
-        X[:, 8] = np.random.normal(0, 1, 1000)
-        X[:, 9] = np.random.poisson(1, 1000)
+        X, y = create_data()

         rfc = RandomForestClassifier()
         bt = BorutaPy(rfc)
@@ -51,7 +105,34 @@ def test_if_boruta_extracts_relevant_features(self):
         # check it dataframe is returned when return_df=True
         self.assertIsInstance(bt.transform(X_df, return_df=True), pd.DataFrame)

-if __name__ == '__main__':
-    unittest.main()
+    def test_xgboost_all_features(self):
+        np.random.seed(42)
+        X, y = create_data()
+
+        bst = Learner(xgboost_parameters, nrounds=10)
+        bt = BorutaPy(bst, n_estimators=bst.nrounds, verbose=True)
+        bt.fit(X, y)
+
+        # make sure that all, and only, the relevant features are returned
+        self.assertListEqual(list(range(5)), list(np.where(bt.support_)[0]))
+
+    def test_xgboost_some_features(self):
+        np.random.seed(42)
+
+        # training data
+        X, y = create_data()
+        C = X[:, :3]  # features that are known to be important
+        T = X[:, 3:]  # features to test -- only the first two in T should turn out to be important
+
+        # Learner
+        bst = Learner(xgboost_parameters, nrounds=25)
+
+        # Boruta
+        bt = BorutaPy(bst, n_estimators=bst.nrounds, max_iter=10, verbose=True)
+        bt.fit(T, y, C)
+        self.assertListEqual(list(range(2)), list(np.where(bt.support_)[0]))
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..0d3c61d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+pandas~=1.2.3
+numpy~=1.20.1
+shap~=0.39.0
+xgboost~=1.3.3
+sklearn~=0.0
+scikit-learn~=0.24.1
+scipy~=1.6.2
+setuptools~=53.1.0
\ No newline at end of file
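
For reviewers, a minimal usage sketch of the new C parameter follows. It is not part of the patch: it assumes the patched BorutaPy above is importable as usual, and it substitutes a RandomForestClassifier for the xgboost/shap Learner used in the tests purely to keep the example dependency-free; the column layout and random data are illustrative.

    # usage sketch (illustrative, not part of the patch)
    import numpy as np
    from sklearn.ensemble import RandomForestClassifier
    from boruta import BorutaPy

    rng = np.random.default_rng(42)
    y = rng.binomial(1, 0.5, 1000)

    # columns we trust a priori: passed as C, never tested or ranked
    C = np.column_stack((y + rng.normal(0, 1, 1000),
                         y * rng.normal(1, 0.5, 1000)))

    # candidate columns: passed as X; one informative, two pure noise
    X = np.column_stack((y ** 2 + rng.normal(0, 1, 1000),
                         rng.normal(0, 1, 1000),
                         rng.poisson(1.0, 1000)))

    bt = BorutaPy(RandomForestClassifier(), n_estimators=100, max_iter=20)
    bt.fit(X, y, C)         # each internal fit appends C after the (shadowed) test columns
    print(bt.support_)      # three flags, one per *tested* column only
    print(bt.transform(X))  # keeps the confirmed tested columns

Because _get_imp stacks C after the tested columns (and after their shadows) and truncates feature_importances_ to X.shape[1], only the tested columns are ever scored or ranked, while C still influences every model fit.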