Skip to content

Commit

Permalink
Add MDI+ local
Browse files Browse the repository at this point in the history
  • Loading branch information
zyliang2001 committed Jan 14, 2024
1 parent 850dd2e commit b519455
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
[FIModelConfig('Permutation', permutation_local, model_type='tree')],
[FIModelConfig('LIME', lime_local, model_type='tree')],
[FIModelConfig('MDI_all_stumps', MDI_local_all_stumps, model_type='tree')],
[FIModelConfig('MDI_sub_stumps', MDI_local_sub_stumps, model_type='tree')],
]
60 changes: 40 additions & 20 deletions feature_importance/scripts/competing_methods_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,16 @@
import lime
import lime.lime_tabular
from imodels.importance.rf_plus import RandomForestPlusRegressor, RandomForestPlusClassifier
from imodels.importance.rf_plus import _fast_r2_score
from sklearn.metrics import r2_score, mean_absolute_error, accuracy_score, roc_auc_score, mean_squared_error


def neg_mae(y_true, y_pred, **kwargs):
"""
Evaluates negative mean absolute error
"""
return -mean_absolute_error(y_true, y_pred, **kwargs)

def tree_shap_local(X, y, fit):
"""
Compute average treeshap value across observations
Expand Down Expand Up @@ -89,8 +97,17 @@ def MDI_local_sub_stumps(X, y, fit):
"""
num_samples, num_features = X.shape

if isinstance(fit, RegressorMixin):
RFPlus = RandomForestPlusRegressor
elif isinstance(fit, ClassifierMixin):
RFPlus = RandomForestPlusClassifier
else:
raise ValueError("Unknown task.")
rf_plus_model = RFPlus(rf_model=fit, **kwargs)
rf_plus_model.fit(X, y)

result = None
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X, y, scoring_fns={"r2_score": _fast_r2_score, "negative_mae": neg_mae}, local_scoring_fns=True)
result = mdi_plus_scores["local"]["negative_mae"]

# Convert the array to a DataFrame
result_table = pd.DataFrame(result, columns=[f'Feature_{i}' for i in range(num_features)])
Expand Down Expand Up @@ -124,25 +141,28 @@ def MDI_local_all_stumps(X, y, fit, scoring_fns="auto", return_stability_scores=
raise ValueError("Unknown task.")
rf_plus_model = RFPlus(rf_model=fit, **kwargs)
rf_plus_model.fit(X, y)
try:
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, scoring_fns=scoring_fns)
if return_stability_scores:
stability_scores = rf_plus_model.get_mdi_plus_stability_scores(B=25)
except ValueError as e:
if str(e) == 'Transformer representation was empty for all trees.':
mdi_plus_scores = pd.DataFrame(data=np.zeros(X.shape[1]), columns=['importance'])
if isinstance(X, pd.DataFrame):
mdi_plus_scores.index = X.columns
mdi_plus_scores.index.name = 'var'
mdi_plus_scores.reset_index(inplace=True)
stability_scores = None
else:
raise
mdi_plus_scores["prediction_score"] = rf_plus_model.prediction_score_
if return_stability_scores:
mdi_plus_scores = pd.concat([mdi_plus_scores, stability_scores], axis=1)

return mdi_plus_scores
# try:
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, local_scoring_fns=mean_squared_error)
result = mdi_plus_scores["local"]
# if return_stability_scores:
# stability_scores = rf_plus_model.get_mdi_plus_stability_scores(B=25)
# except ValueError as e:
# if str(e) == 'Transformer representation was empty for all trees.':
# mdi_plus_scores = pd.DataFrame(data=np.zeros(X.shape[1]), columns=['importance'])
# if isinstance(X, pd.DataFrame):
# mdi_plus_scores.index = X.columns
# mdi_plus_scores.index.name = 'var'
# mdi_plus_scores.reset_index(inplace=True)
# stability_scores = None
# else:
# raise
# mdi_plus_scores["prediction_score"] = rf_plus_model.prediction_score_
# if return_stability_scores:
# mdi_plus_scores = pd.concat([mdi_plus_scores, stability_scores], axis=1)

result_table = pd.DataFrame(result, columns=[f'Feature_{i}' for i in range(num_features)])

return result_table

def lime_local(X, y, fit):
"""
Expand Down

0 comments on commit b519455

Please sign in to comment.