Skip to content

Commit

Permalink
FIx all bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
zyliang2001 committed Jan 14, 2024
1 parent 730298f commit 18312c0
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 936 deletions.
1 change: 1 addition & 0 deletions feature_importance/01_run_importance_local_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def compare_estimators(estimators: List[ModelConfig],
}
start = time.time()
local_fi_score = fi_est.cls(X_test, y_test, copy.deepcopy(est), **fi_est.kwargs)
print(local_fi_score)
assert local_fi_score.shape == X_test.shape
n_local_fi_score = len(local_fi_score)
local_fi_score_group1 = local_fi_score.iloc[range(n_local_fi_score // 2)].values
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
X_DGP = sample_normal_X
X_PARAMS_DICT = {
"n": 1200,
"d": 50,
"d": 20,
"mean": 0,
"scale": 1
}
Expand All @@ -26,7 +26,7 @@
# vary two parameters in a grid
VARY_PARAM_NAME = ["heritability", "sample_row_n"]
VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
"sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000}}
"n": {"100": 100, "250": 250, "500": 500, "1000": 1000}}

# # vary over n_estimators in RF model in models.py
# VARY_PARAM_NAME = "n_estimators"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
]

FI_ESTIMATORS = [
[FIModelConfig('MDI_all_stumps', MDI_local_all_stumps, model_type='tree')],
[FIModelConfig('MDI_sub_stumps', MDI_local_sub_stumps, model_type='tree')],
[FIModelConfig('MDI_all_stumps', MDI_local_all_stumps, ascending = False, model_type='tree')],
[FIModelConfig('MDI_sub_stumps', MDI_local_sub_stumps, ascending = False, model_type='tree')],
[FIModelConfig('TreeSHAP', tree_shap_local, model_type='tree')],
[FIModelConfig('Permutation', permutation_local, model_type='tree')],
[FIModelConfig('LIME', lime_local, model_type='tree')],
Expand Down
Loading

0 comments on commit 18312c0

Please sign in to comment.