-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update hierarchy, build treeshap local,
- Loading branch information
1 parent
7e3390e
commit 85c21f1
Showing
6 changed files
with
188 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
33 changes: 33 additions & 0 deletions
33
feature_importance/fi_config/mdi_local/two_subgroups_linear_sims/dgp.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import sys | ||
sys.path.append("../..") | ||
from feature_importance.scripts.simulations_util import * | ||
|
||
|
||
# Covariate generator and the keyword arguments it is called with.
X_DGP = sample_real_X
X_PARAMS_DICT = dict(
    fpath="data/X_splicing_cleaned.csv",
    sample_row_n=None,
    sample_col_n=None,
)

### Local MDI+ settings: begin
# Response generator and the keyword arguments it is called with.
Y_DGP = linear_model_two_groups
Y_PARAMS_DICT = dict(
    beta=1,
    sigma=None,
    heritability=0.4,
    s=5,
)
### Local MDI+ settings: end

# Alternative: vary a single parameter.
# VARY_PARAM_NAME = "sample_row_n"
# VARY_PARAM_VALS = {"100": 100, "250": 250, "500": 500, "1000": 1000}

# Vary two parameters over a grid; every inner mapping goes from a display
# label (the value rendered as a string) to the value itself.
VARY_PARAM_NAME = ["heritability", "sample_row_n"]
VARY_PARAM_VALS = {
    "heritability": {str(h): h for h in (0.1, 0.2, 0.4, 0.8)},
    "sample_row_n": {str(n): n for n in (100, 250, 500, 1000)},
}

# Alternative: vary n_estimators of the RF model configured in models.py.
# VARY_PARAM_NAME = "n_estimators"
# VARY_PARAM_VALS = {"placeholder": 0}
15 changes: 15 additions & 0 deletions
15
feature_importance/fi_config/mdi_local/two_subgroups_linear_sims/models.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from sklearn.ensemble import RandomForestRegressor | ||
from feature_importance.util import ModelConfig, FIModelConfig | ||
from feature_importance.scripts.competing_methods_local import tree_shap_local | ||
|
||
# N_ESTIMATORS=[50, 100, 500, 1000]

# Prediction models: a list of groups, each group being a list of ModelConfig.
ESTIMATORS = [
    [
        ModelConfig(
            'RF',
            RandomForestRegressor,
            model_type='tree',
            other_params=dict(
                n_estimators=100,
                min_samples_leaf=5,
                max_features=0.33,
            ),
        ),
    ],
    # Alternative group: one RF per value in N_ESTIMATORS.
    # [ModelConfig('RF', RandomForestRegressor, model_type='tree', vary_param="n_estimators", vary_param_val=m,
    #              other_params={'min_samples_leaf': 5, 'max_features': 0.33}) for m in N_ESTIMATORS]
]

# Feature importance methods, grouped the same way as ESTIMATORS.
FI_ESTIMATORS = [
    [
        FIModelConfig('TreeSHAP', tree_shap_local, model_type='tree'),
    ],
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import os | ||
import sys | ||
import pandas as pd | ||
import numpy as np | ||
import sklearn.base | ||
from sklearn.base import RegressorMixin, ClassifierMixin | ||
from functools import reduce | ||
|
||
import shap | ||
|
||
|
||
def tree_shap_local(X, y, fit):
    """
    Compute local (per-observation) absolute TreeSHAP values.

    :param X: design matrix
    :param y: response (not used by this function)
    :param fit: fitted model of interest (tree-based)
    :return: dataframe with one row per observation in X and one column per
             feature (built directly from the SHAP array, so columns carry
             positional labels rather than feature names). Each entry is the
             absolute SHAP value of that feature for that observation; for
             classifiers the absolute values are summed over classes.
    """
    explainer = shap.TreeExplainer(fit)
    shap_values = explainer.shap_values(X, check_additivity=False)

    if sklearn.base.is_classifier(fit) and isinstance(shap_values, list):
        # One array per class. Take abs of every class before summing so a
        # single-element list is also made non-negative (a plain pairwise
        # reduce would return that lone element untouched).
        results = sum(np.abs(class_values) for class_values in shap_values)
    elif sklearn.base.is_classifier(fit) and np.ndim(shap_values) == 3:
        # Single array with the class on the last axis (as returned by newer
        # shap releases): collapse the class axis, not the observation axis.
        results = np.abs(shap_values).sum(axis=-1)
    else:
        results = np.abs(shap_values)

    return pd.DataFrame(results)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters