From 850dd2e14b8b95c2751ea1c142b276d5b4e32d3d Mon Sep 17 00:00:00 2001 From: zyliang2001 Date: Sat, 13 Jan 2024 22:50:12 -0800 Subject: [PATCH] Update of Lime, etc --- .../two_subgroups_linear_sims/models.py | 5 +- .../local_MDI_plus_visulization.ipynb | 790 ++++++++++++++++++ .../scripts/competing_methods_local.py | 88 +- feature_importance/util.py | 6 +- 4 files changed, 855 insertions(+), 34 deletions(-) diff --git a/feature_importance/fi_config/mdi_local/two_subgroups_linear_sims/models.py b/feature_importance/fi_config/mdi_local/two_subgroups_linear_sims/models.py index 87539d2..bffd190 100644 --- a/feature_importance/fi_config/mdi_local/two_subgroups_linear_sims/models.py +++ b/feature_importance/fi_config/mdi_local/two_subgroups_linear_sims/models.py @@ -1,7 +1,6 @@ from sklearn.ensemble import RandomForestRegressor from feature_importance.util import ModelConfig, FIModelConfig -from feature_importance.scripts.competing_methods_local import tree_shap_local, permutation_local - +from feature_importance.scripts.competing_methods_local import tree_shap_local, permutation_local, lime_local, MDI_local_all_stumps, MDI_local_sub_stumps # N_ESTIMATORS=[50, 100, 500, 1000] ESTIMATORS = [ [ModelConfig('RF', RandomForestRegressor, model_type='tree', @@ -13,4 +12,6 @@ FI_ESTIMATORS = [ [FIModelConfig('TreeSHAP', tree_shap_local, model_type='tree')], [FIModelConfig('Permutation', permutation_local, model_type='tree')], + [FIModelConfig('LIME', lime_local, model_type='tree')], + [FIModelConfig('MDI_all_stumps', MDI_local_all_stumps, model_type='tree')], ] \ No newline at end of file diff --git a/feature_importance/local_MDI_plus_visulization.ipynb b/feature_importance/local_MDI_plus_visulization.ipynb index 90547ec..e6b03d7 100644 --- a/feature_importance/local_MDI_plus_visulization.ipynb +++ b/feature_importance/local_MDI_plus_visulization.ipynb @@ -10,6 +10,399 @@ "import matplotlib.pyplot as plt" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting limeNote: you may need to restart the kernel to use updated packages.\n", + "\n", + " Downloading lime-0.2.0.1.tar.gz (275 kB)\n", + " ---------------------------------------- 0.0/275.7 kB ? eta -:--:--\n", + " ----------------------- -------------- 174.1/275.7 kB 3.5 MB/s eta 0:00:01\n", + " -------------------------------------- 275.7/275.7 kB 4.3 MB/s eta 0:00:00\n", + " Preparing metadata (setup.py): started\n", + " Preparing metadata (setup.py): finished with status 'done'\n", + "Requirement already satisfied: matplotlib in d:\\anaconda\\lib\\site-packages (from lime) (3.7.1)\n", + "Requirement already satisfied: numpy in d:\\anaconda\\lib\\site-packages (from lime) (1.24.3)\n", + "Requirement already satisfied: scipy in d:\\anaconda\\lib\\site-packages (from lime) (1.10.1)\n", + "Requirement already satisfied: tqdm in d:\\anaconda\\lib\\site-packages (from lime) (4.65.0)\n", + "Requirement already satisfied: scikit-learn>=0.18 in d:\\anaconda\\lib\\site-packages (from lime) (1.3.0)\n", + "Requirement already satisfied: scikit-image>=0.12 in d:\\anaconda\\lib\\site-packages (from lime) (0.20.0)\n", + "Requirement already satisfied: networkx>=2.8 in d:\\anaconda\\lib\\site-packages (from scikit-image>=0.12->lime) (3.1)\n", + "Requirement already satisfied: pillow>=9.0.1 in d:\\anaconda\\lib\\site-packages (from scikit-image>=0.12->lime) (9.4.0)\n", + "Requirement already satisfied: imageio>=2.4.1 in d:\\anaconda\\lib\\site-packages (from scikit-image>=0.12->lime) (2.26.0)\n", + "Requirement already satisfied: tifffile>=2019.7.26 in d:\\anaconda\\lib\\site-packages (from scikit-image>=0.12->lime) (2021.7.2)\n", + "Requirement already satisfied: PyWavelets>=1.1.1 in d:\\anaconda\\lib\\site-packages (from scikit-image>=0.12->lime) (1.4.1)\n", + "Requirement already satisfied: packaging>=20.0 in c:\\users\\administrator\\appdata\\roaming\\python\\python311\\site-packages (from scikit-image>=0.12->lime) (23.1)\n", + "Requirement already satisfied: lazy_loader>=0.1 in d:\\anaconda\\lib\\site-packages (from scikit-image>=0.12->lime) (0.2)\n", + "Requirement already satisfied: joblib>=1.1.1 in d:\\anaconda\\lib\\site-packages (from scikit-learn>=0.18->lime) (1.2.0)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in d:\\anaconda\\lib\\site-packages (from scikit-learn>=0.18->lime) (2.2.0)\n", + "Requirement already satisfied: contourpy>=1.0.1 in d:\\anaconda\\lib\\site-packages (from matplotlib->lime) (1.0.5)\n", + "Requirement already satisfied: cycler>=0.10 in d:\\anaconda\\lib\\site-packages (from matplotlib->lime) (0.11.0)\n", + "Requirement already satisfied: fonttools>=4.22.0 in d:\\anaconda\\lib\\site-packages (from matplotlib->lime) (4.25.0)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in d:\\anaconda\\lib\\site-packages (from matplotlib->lime) (1.4.4)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in d:\\anaconda\\lib\\site-packages (from matplotlib->lime) (3.0.9)\n", + "Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\administrator\\appdata\\roaming\\python\\python311\\site-packages (from matplotlib->lime) (2.8.2)\n", + "Requirement already satisfied: colorama in c:\\users\\administrator\\appdata\\roaming\\python\\python311\\site-packages (from tqdm->lime) (0.4.6)\n", + "Requirement already satisfied: six>=1.5 in c:\\users\\administrator\\appdata\\roaming\\python\\python311\\site-packages (from python-dateutil>=2.7->matplotlib->lime) (1.16.0)\n", + "Building wheels for collected packages: lime\n", + " Building wheel for lime (setup.py): started\n", + " Building wheel for lime (setup.py): finished with status 'done'\n", + " Created wheel for lime: filename=lime-0.2.0.1-py3-none-any.whl size=283846 sha256=fd212909233cda90b2608b5a7687f2f6daf8cf45d1bd5635eea43293fcae7cc0\n", + " Stored in directory: c:\\users\\administrator\\appdata\\local\\pip\\cache\\wheels\\85\\fa\\a3\\9c2d44c9f3cd77cf4e533b58900b2bf4487f2a17e8ec212a3d\n", + "Successfully built lime\n", + "Installing collected packages: lime\n", + "Successfully installed lime-0.2.0.1\n" + ] + } + ], + "source": [ + "pip install lime" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import lime\n", + "import pandas as pd\n", + "import sklearn.ensemble\n", + "import numpy as np\n", + "import lime.lime_tabular" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "rf = sklearn.ensemble.RandomForestRegressor(n_estimators=100)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
RandomForestRegressor()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" + ], + "text/plain": [ + "RandomForestRegressor()" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X = np.random.rand(100, 10)\n", + "y = X[:,0]\n", + "rf.fit(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "def LIME(X, y, fit):\n", + " \"\"\"\n", + " Compute LIME local importance for each feature and sample.\n", + " :param X: design matrix\n", + " :param y: response\n", + " :param fit: fitted model of interest (tree-based)\n", + " :return: dataframe of shape: (n_samples, n_features)\n", + "\n", + " \"\"\"\n", + "\n", + " \n", + " np.random.seed(1)\n", + " num_samples, num_features = X.shape\n", + " result = np.zeros((num_samples, num_features))\n", + " explainer = lime.lime_tabular.LimeTabularExplainer(X, verbose=False, mode='regression')\n", + " for i in range(num_samples):\n", + " exp = explainer.explain_instance(X[i], fit.predict, num_features=num_features)\n", + " original_feature_importance = exp.as_map()[1]\n", + " sorted_feature_importance = sorted(original_feature_importance, key=lambda x: x[0])\n", + " for j in range(num_features):\n", + " result[i,j] = abs(sorted_feature_importance[j][1])\n", + " # Convert the array to a DataFrame\n", + " result_table = pd.DataFrame(result, columns=[f'Feature_{i}' for i in range(num_features)])\n", + "\n", + " return result_table" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[(0, 0.21946497746237179),\n", + " (2, 0.016300293940867162),\n", + " (6, 0.014334259010884842),\n", + " (5, -0.013599234218149058),\n", + " (1, -0.010064527457576576),\n", + " (9, -0.00982332987593337),\n", + " (4, -0.007714466944431451),\n", + " (3, -0.0076925705634553485),\n", + " (7, 0.001754464109452201),\n", + " (8, -0.0003095663892797081)]" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.random.seed(1)\n", + "explainer = lime.lime_tabular.LimeTabularExplainer(X, verbose=False, mode='regression')\n", + "exp = explainer.explain_instance(X[0], rf.predict, num_features=10)\n", + "exp.as_map()[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Feature_0Feature_1Feature_2Feature_3Feature_4Feature_5Feature_6Feature_7Feature_8Feature_9
00.2194650.0100650.0163000.0076930.0077140.0135990.0143340.0017540.0003100.009823
10.1716770.0045220.0159060.0068410.0048560.0101850.0064670.0007080.0059850.017116
20.2178540.0076750.0085740.0034340.0092130.0052860.0068830.0059050.0087290.009954
30.5015070.0051700.0031920.0030660.0142720.0039020.0064600.0062220.0014130.006189
40.2290030.0009100.0099000.0014570.0098450.0001350.0123310.0050810.0001840.011474
.................................
950.5059640.0002290.0146320.0005210.0040260.0016700.0016110.0048970.0059700.000782
960.4928300.0025740.0121790.0041590.0077350.0097300.0037870.0103330.0040510.000026
970.1775080.0048060.0007740.0054790.0081250.0071910.0023430.0033810.0063340.013973
980.5525410.0059580.0053250.0098920.0034530.0006620.0058280.0051840.0033140.002250
990.5461160.0054290.0068830.0019420.0003460.0031260.0075020.0043870.0004910.002831
\n", + "

100 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " Feature_0 Feature_1 Feature_2 Feature_3 Feature_4 Feature_5 \\\n", + "0 0.219465 0.010065 0.016300 0.007693 0.007714 0.013599 \n", + "1 0.171677 0.004522 0.015906 0.006841 0.004856 0.010185 \n", + "2 0.217854 0.007675 0.008574 0.003434 0.009213 0.005286 \n", + "3 0.501507 0.005170 0.003192 0.003066 0.014272 0.003902 \n", + "4 0.229003 0.000910 0.009900 0.001457 0.009845 0.000135 \n", + ".. ... ... ... ... ... ... \n", + "95 0.505964 0.000229 0.014632 0.000521 0.004026 0.001670 \n", + "96 0.492830 0.002574 0.012179 0.004159 0.007735 0.009730 \n", + "97 0.177508 0.004806 0.000774 0.005479 0.008125 0.007191 \n", + "98 0.552541 0.005958 0.005325 0.009892 0.003453 0.000662 \n", + "99 0.546116 0.005429 0.006883 0.001942 0.000346 0.003126 \n", + "\n", + " Feature_6 Feature_7 Feature_8 Feature_9 \n", + "0 0.014334 0.001754 0.000310 0.009823 \n", + "1 0.006467 0.000708 0.005985 0.017116 \n", + "2 0.006883 0.005905 0.008729 0.009954 \n", + "3 0.006460 0.006222 0.001413 0.006189 \n", + "4 0.012331 0.005081 0.000184 0.011474 \n", + ".. ... ... ... ... \n", + "95 0.001611 0.004897 0.005970 0.000782 \n", + "96 0.003787 0.010333 0.004051 0.000026 \n", + "97 0.002343 0.003381 0.006334 0.013973 \n", + "98 0.005828 0.005184 0.003314 0.002250 \n", + "99 0.007502 0.004387 0.000491 0.002831 \n", + "\n", + "[100 rows x 10 columns]" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "LIME( X, y, rf)" + ] + }, { "cell_type": "code", "execution_count": 4, @@ -40,6 +433,403 @@ "df = pd.read_csv(\"./results/mdi_local.two_subgroups_linear_sims/varying_heritability_sample_row_n/seed331/results.csv\")" ] }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
repsample_row_nsample_row_n_nameheritabilityheritability_namen_estimatorsmin_samples_leafmax_featuresmodelfi...timesplit_seedindexvarlocal_fi_score_group1_meanlocal_fi_score_group2_meantrue_support_group1true_support_group2cor_with_signal_group1cor_with_signal_group2
001001000.10.110050.33RFPermutation...82.0613163310029.68875338.9637311.00.0NaN0.775961
101001000.10.110050.33RFPermutation...82.0613163310129.68875338.9637311.00.0NaN0.290743
201001000.10.110050.33RFPermutation...82.0613163310229.68875338.9637311.00.0NaN0.082066
301001000.10.110050.33RFPermutation...82.0613163310329.71830338.9567611.00.0NaN0.439064
401001000.10.110050.33RFPermutation...82.0613163310429.68875338.9637311.00.0NaN0.405487
..................................................................
84430100010000.80.810050.33RFTreeSHAP...4.117252331312590.0356050.0313080.00.00.5168050.363219
84440100010000.80.810050.33RFTreeSHAP...4.117252331312600.0024860.0021890.00.00.6890100.306485
84450100010000.80.810050.33RFTreeSHAP...4.117252331312610.0029690.0023970.00.00.5616310.337694
84460100010000.80.810050.33RFTreeSHAP...4.117252331312620.0425980.0383440.00.00.7943730.522991
84470100010000.80.810050.33RFTreeSHAP...4.117252331312630.0095690.0075200.00.00.7480930.487511
\n", + "

8448 rows × 23 columns

\n", + "
" + ], + "text/plain": [ + " rep sample_row_n sample_row_n_name heritability heritability_name \\\n", + "0 0 100 100 0.1 0.1 \n", + "1 0 100 100 0.1 0.1 \n", + "2 0 100 100 0.1 0.1 \n", + "3 0 100 100 0.1 0.1 \n", + "4 0 100 100 0.1 0.1 \n", + "... ... ... ... ... ... \n", + "8443 0 1000 1000 0.8 0.8 \n", + "8444 0 1000 1000 0.8 0.8 \n", + "8445 0 1000 1000 0.8 0.8 \n", + "8446 0 1000 1000 0.8 0.8 \n", + "8447 0 1000 1000 0.8 0.8 \n", + "\n", + " n_estimators min_samples_leaf max_features model fi ... \\\n", + "0 100 5 0.33 RF Permutation ... \n", + "1 100 5 0.33 RF Permutation ... \n", + "2 100 5 0.33 RF Permutation ... \n", + "3 100 5 0.33 RF Permutation ... \n", + "4 100 5 0.33 RF Permutation ... \n", + "... ... ... ... ... ... ... \n", + "8443 100 5 0.33 RF TreeSHAP ... \n", + "8444 100 5 0.33 RF TreeSHAP ... \n", + "8445 100 5 0.33 RF TreeSHAP ... \n", + "8446 100 5 0.33 RF TreeSHAP ... \n", + "8447 100 5 0.33 RF TreeSHAP ... \n", + "\n", + " time split_seed index var local_fi_score_group1_mean \\\n", + "0 82.061316 331 0 0 29.688753 \n", + "1 82.061316 331 0 1 29.688753 \n", + "2 82.061316 331 0 2 29.688753 \n", + "3 82.061316 331 0 3 29.718303 \n", + "4 82.061316 331 0 4 29.688753 \n", + "... ... ... ... ... ... \n", + "8443 4.117252 331 31 259 0.035605 \n", + "8444 4.117252 331 31 260 0.002486 \n", + "8445 4.117252 331 31 261 0.002969 \n", + "8446 4.117252 331 31 262 0.042598 \n", + "8447 4.117252 331 31 263 0.009569 \n", + "\n", + " local_fi_score_group2_mean true_support_group1 true_support_group2 \\\n", + "0 38.963731 1.0 0.0 \n", + "1 38.963731 1.0 0.0 \n", + "2 38.963731 1.0 0.0 \n", + "3 38.956761 1.0 0.0 \n", + "4 38.963731 1.0 0.0 \n", + "... ... ... ... \n", + "8443 0.031308 0.0 0.0 \n", + "8444 0.002189 0.0 0.0 \n", + "8445 0.002397 0.0 0.0 \n", + "8446 0.038344 0.0 0.0 \n", + "8447 0.007520 0.0 0.0 \n", + "\n", + " cor_with_signal_group1 cor_with_signal_group2 \n", + "0 NaN 0.775961 \n", + "1 NaN 0.290743 \n", + "2 NaN 0.082066 \n", + "3 NaN 0.439064 \n", + "4 NaN 0.405487 \n", + "... ... ... \n", + "8443 0.516805 0.363219 \n", + "8444 0.689010 0.306485 \n", + "8445 0.561631 0.337694 \n", + "8446 0.794373 0.522991 \n", + "8447 0.748093 0.487511 \n", + "\n", + "[8448 rows x 23 columns]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df" + ] + }, { "cell_type": "code", "execution_count": 7, diff --git a/feature_importance/scripts/competing_methods_local.py b/feature_importance/scripts/competing_methods_local.py index f59423f..23404fd 100644 --- a/feature_importance/scripts/competing_methods_local.py +++ b/feature_importance/scripts/competing_methods_local.py @@ -8,6 +8,9 @@ from functools import reduce import shap +import lime +import lime.lime_tabular +from imodels.importance.rf_plus import RandomForestPlusRegressor, RandomForestPlusClassifier def tree_shap_local(X, y, fit): @@ -75,17 +78,15 @@ def permutation_local(X, y, fit, num_permutations=100): return result_table ##########To Do for Zach: Please add the implementation of local MDI and MDI+ below########## -def MDI_plus_local(X, y, fit): +def MDI_local_sub_stumps(X, y, fit): """ - Compute local MDI+ importance for each feature and sample. + Compute local MDI importance for each feature and sample. :param X: design matrix :param y: response :param fit: fitted model of interest (tree-based) :return: dataframe of shape: (n_samples, n_features) """ - - ## To Do for Zach: Please add the implementation of local MDI+ below num_samples, num_features = X.shape @@ -96,29 +97,54 @@ def MDI_plus_local(X, y, fit): return result_table - -def MDI_local(X, y, fit): +def MDI_local_all_stumps(X, y, fit, scoring_fns="auto", return_stability_scores=False, **kwargs): """ - Compute local MDI importance for each feature and sample. - :param X: design matrix - :param y: response - :param fit: fitted model of interest (tree-based) - :return: dataframe of shape: (n_samples, n_features) - + Wrapper around MDI+ object to get feature importance scores + + :param X: ndarray of shape (n_samples, n_features) + The covariate matrix. If a pd.DataFrame object is supplied, then + the column names are used in the output + :param y: ndarray of shape (n_samples, n_targets) + The observed responses. + :param rf_model: scikit-learn random forest object or None + The RF model to be used for interpretation. If None, then a new + RandomForestRegressor or RandomForestClassifier is instantiated. + :param kwargs: additional arguments to pass to + RandomForestPlusRegressor or RandomForestPlusClassifier class. + :return: dataframe - [Var, Importance] + Var: variable name + Importance: MDI+ score """ - ## To Do for Zach: Please add the implementation of local MDI below - num_samples, num_features = X.shape - - - result = None - - # Convert the array to a DataFrame - result_table = pd.DataFrame(result, columns=[f'Feature_{i}' for i in range(num_features)]) - - return result_table - -def LIME(X, y, fit): + if isinstance(fit, RegressorMixin): + RFPlus = RandomForestPlusRegressor + elif isinstance(fit, ClassifierMixin): + RFPlus = RandomForestPlusClassifier + else: + raise ValueError("Unknown task.") + rf_plus_model = RFPlus(rf_model=fit, **kwargs) + rf_plus_model.fit(X, y) + try: + mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, scoring_fns=scoring_fns) + if return_stability_scores: + stability_scores = rf_plus_model.get_mdi_plus_stability_scores(B=25) + except ValueError as e: + if str(e) == 'Transformer representation was empty for all trees.': + mdi_plus_scores = pd.DataFrame(data=np.zeros(X.shape[1]), columns=['importance']) + if isinstance(X, pd.DataFrame): + mdi_plus_scores.index = X.columns + mdi_plus_scores.index.name = 'var' + mdi_plus_scores.reset_index(inplace=True) + stability_scores = None + else: + raise + mdi_plus_scores["prediction_score"] = rf_plus_model.prediction_score_ + if return_stability_scores: + mdi_plus_scores = pd.concat([mdi_plus_scores, stability_scores], axis=1) + + return mdi_plus_scores + +def lime_local(X, y, fit): """ Compute LIME local importance for each feature and sample. :param X: design matrix @@ -128,12 +154,16 @@ def LIME(X, y, fit): """ - ## To Do for Zach: Please add the implementation of local MDI below + np.random.seed(1) num_samples, num_features = X.shape - - - result = None - + result = np.zeros((num_samples, num_features)) + explainer = lime.lime_tabular.LimeTabularExplainer(X, verbose=False, mode='regression') + for i in range(num_samples): + exp = explainer.explain_instance(X[i], fit.predict, num_features=num_features) + original_feature_importance = exp.as_map()[1] + sorted_feature_importance = sorted(original_feature_importance, key=lambda x: x[0]) + for j in range(num_features): + result[i,j] = abs(sorted_feature_importance[j][1]) # Convert the array to a DataFrame result_table = pd.DataFrame(result, columns=[f'Feature_{i}' for i in range(num_features)]) diff --git a/feature_importance/util.py b/feature_importance/util.py index 8b0b97e..cb95e03 100644 --- a/feature_importance/util.py +++ b/feature_importance/util.py @@ -13,7 +13,7 @@ from sklearn.preprocessing import label_binarize from sklearn.utils._encode import _unique from sklearn import metrics -# from imodels.importance.ppms import huber_loss +from imodels.importance.ppms import huber_loss DATASET_PATH = oj(dirname(os.path.realpath(__file__)), 'data') @@ -144,8 +144,8 @@ def neg_mean_absolute_error(y_true, y_pred, **kwargs): return -mean_absolute_error(y_true, y_pred, **kwargs) -# def neg_huber_loss(y_true, y_pred, **kwargs): -# return -huber_loss(y_true, y_pred, **kwargs) +def neg_huber_loss(y_true, y_pred, **kwargs): + return -huber_loss(y_true, y_pred, **kwargs) def restricted_roc_auc_score(y_true, y_score, ignored_indices=[]):