Simulations for MDI+ (#16)
* Add code and notebooks to reproduce simulations and plots for "MDI+: A Flexible Random Forest-Based Feature Importance Framework"
tiffanymtang authored Jul 4, 2023
1 parent b778639 commit 779ac00
Showing 242 changed files with 17,984 additions and 0 deletions.
475 changes: 475 additions & 0 deletions feature_importance/01_run_importance_simulations.py

Large diffs are not rendered by default.

402 changes: 402 additions & 0 deletions feature_importance/02_run_importance_real_data.py

Large diffs are not rendered by default.

422 changes: 422 additions & 0 deletions feature_importance/03_run_prediction_simulations.py

Large diffs are not rendered by default.

333 changes: 333 additions & 0 deletions feature_importance/04_run_prediction_real_data.py
@@ -0,0 +1,333 @@
# Example usage: run in command line
# cd feature_importance/
# python 04_run_prediction_real_data.py --nreps 2 --config test --split_seed 12345 --ignore_cache
# python 04_run_prediction_real_data.py --nreps 2 --config test --split_seed 12345 --ignore_cache --splitting_strategy train-tune-test
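# A parallel run is also possible via dask (a hedged sketch, not from the original
# comments; all flags below are defined in this script's argparse section):
# python 04_run_prediction_real_data.py --nreps 5 --config test --split_seed 12345 --parallel --n_cores 4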

import copy
import os
from os.path import join as oj
import glob
import argparse
import pickle as pkl
import time
import warnings
from scipy import stats
import dask
from dask.distributed import Client
import numpy as np
import pandas as pd
from tqdm import tqdm
import sys
from collections import defaultdict
from typing import Callable, List, Tuple
import itertools
from functools import partial

sys.path.append(".")
sys.path.append("..")
sys.path.append("../..")
import fi_config
from util import ModelConfig, apply_splitting_strategy, auroc_score, auprc_score

from sklearn.metrics import accuracy_score, f1_score, recall_score, \
precision_score, average_precision_score, r2_score, explained_variance_score, \
mean_squared_error, mean_absolute_error, log_loss

warnings.filterwarnings("ignore", message="Bins whose width")


def compare_estimators(estimators: List[ModelConfig],
                       X, y,
                       metrics: List[Tuple[str, Callable]],
                       args, rep) -> dict:
    """Calculates prediction results for each estimator on the given dataset.
    Called in run_comparison.
    """
    if not isinstance(estimators, list):
        raise Exception("First argument needs to be a list of ModelConfigs")
    if not isinstance(metrics, list):
        raise Exception("Argument metrics needs to be a list containing ('name', callable) pairs")

# initialize results
results = defaultdict(lambda: [])

if args.splitting_strategy is not None:
X_train, X_tune, X_test, y_train, y_tune, y_test = apply_splitting_strategy(
X, y, args.splitting_strategy, args.split_seed + rep)
else:
X_train = X
X_tune = X
X_test = X
y_train = y
y_tune = y
y_test = y

# loop over model estimators
for model in tqdm(estimators, leave=False):
est = model.cls(**model.kwargs)

start = time.time()
est.fit(X_train, y_train)
end = time.time()

metric_results = {'model': model.name}
y_pred = est.predict(X_test)
if args.mode != 'regression':
y_pred_proba = est.predict_proba(X_test)
if args.mode == 'binary_classification':
y_pred_proba = y_pred_proba[:, 1]
else:
y_pred_proba = y_pred
for met_name, met in metrics:
if met is not None:
if args.mode == 'regression' \
or met_name in ['accuracy', 'f1', 'precision', 'recall']:
metric_results[met_name] = met(y_test, y_pred)
else:
metric_results[met_name] = met(y_test, y_pred_proba)
metric_results['predictions'] = copy.deepcopy(pd.DataFrame(y_pred_proba))
metric_results['time'] = end - start

# initialize results with metadata and metric results
kwargs: dict = model.kwargs # dict
for k in kwargs:
results[k].append(kwargs[k])
for met_name, met_val in metric_results.items():
results[met_name].append(met_val)

return results


def run_comparison(rep: int,
path: str,
X, y,
metrics: List[Tuple[str, Callable]],
estimators: List[ModelConfig],
args):
estimator_name = estimators[0].name.split(' - ')[0]
model_comparison_file = oj(path, f'{estimator_name}_comparisons.pkl')
if args.parallel_id is not None:
model_comparison_file = f'_{args.parallel_id[0]}.'.join(model_comparison_file.split('.'))

if os.path.isfile(model_comparison_file) and not args.ignore_cache:
print(f'{estimator_name} results already computed and cached. use --ignore_cache to recompute')
return

results = compare_estimators(estimators=estimators,
X=X, y=y,
metrics=metrics,
args=args,
rep=rep)

estimators_list = [e.name for e in estimators]

df = pd.DataFrame.from_dict(results)
df['split_seed'] = args.split_seed + rep
if args.nosave_cols is not None:
nosave_cols = np.unique([x.strip() for x in args.nosave_cols.split(",")])
else:
nosave_cols = []
for col in nosave_cols:
if col in df.columns:
df = df.drop(columns=[col])

output_dict = {
# metadata
'sim_name': args.config,
'estimators': estimators_list,

# actual values
'df': df,
}
pkl.dump(output_dict, open(model_comparison_file, 'wb'))

return df


def reformat_results(results):
    """Explodes each row's 'predictions' DataFrame into long format (one row per
    test sample) and merges it back onto the corresponding results row."""
    results = results.reset_index().drop(columns=['index'])
    predictions = pd.concat(results.pop('predictions').to_dict()). \
        reset_index(level=[0, 1]).rename(columns={'level_0': 'index', 'level_1': 'sample_id'})
    results_df = pd.merge(results, predictions, left_index=True, right_on="index")
    return results_df


def get_metrics(mode: str = 'regression'):
if mode == 'binary_classification':
return [
('rocauc', auroc_score),
('prauc', auprc_score),
('logloss', log_loss),
('accuracy', accuracy_score),
('f1', f1_score),
('recall', recall_score),
('precision', precision_score),
('avg_precision', average_precision_score)
]
elif mode == 'multiclass_classification':
return [
('rocauc', partial(auroc_score, multi_class="ovr")),
('prauc', partial(auprc_score, multi_class="ovr")),
('logloss', log_loss),
('accuracy', accuracy_score),
('f1', partial(f1_score, average='micro')),
('recall', partial(recall_score, average='micro')),
('precision', partial(precision_score, average='micro'))
]
elif mode == 'regression':
return [
('r2', r2_score),
('explained_variance', explained_variance_score),
('mean_squared_error', mean_squared_error),
('mean_absolute_error', mean_absolute_error),
]


def run_simulation(i, path, Xpath, ypath, ests, metrics, args):
X_df = pd.read_csv(Xpath)
y_df = pd.read_csv(ypath)
if args.subsample_n is not None:
if args.subsample_n < X_df.shape[0]:
keep_rows = np.random.choice(X_df.shape[0], args.subsample_n, replace=False)
X_df = X_df.iloc[keep_rows]
y_df = y_df.iloc[keep_rows]
if args.response_idx is None:
keep_cols = y_df.columns
else:
keep_cols = [args.response_idx]
for col in keep_cols:
y = y_df[col].to_numpy().ravel()
keep_idx = ~pd.isnull(y)
X = X_df[keep_idx].to_numpy()
y = y[keep_idx]
if y_df.shape[1] > 1:
output_path = oj(path, col)
else:
output_path = path
os.makedirs(oj(output_path, "rep" + str(i)), exist_ok=True)
for est in ests:
for idx in range(len(est)):
if "random_state" in est[idx].kwargs.keys():
est[idx].kwargs["random_state"] = i
results = run_comparison(
rep=i,
path=oj(output_path, "rep" + str(i)),
X=X, y=y,
metrics=metrics,
estimators=est,
args=args
)

return True


if __name__ == '__main__':

parser = argparse.ArgumentParser()

default_dir = os.getenv("SCRATCH")
if default_dir is not None:
default_dir = oj(default_dir, "feature_importance", "results")
else:
default_dir = oj(os.path.dirname(os.path.realpath(__file__)), 'results')

parser.add_argument('--nreps', type=int, default=2)
parser.add_argument('--mode', type=str, default='regression')
parser.add_argument('--model', type=str, default=None)
parser.add_argument('--config', type=str, default='mdi_plus.prediction_sims.ccle_rnaseq_regression-')
parser.add_argument('--response_idx', type=str, default=None)
parser.add_argument('--subsample_n', type=int, default=None)
parser.add_argument('--nosave_cols', type=str, default=None)

# for multiple reruns, should support varying split_seed
parser.add_argument('--ignore_cache', action='store_true', default=False)
parser.add_argument('--splitting_strategy', type=str, default="train-test")
parser.add_argument('--verbose', action='store_true', default=True)
parser.add_argument('--parallel', action='store_true', default=False)
parser.add_argument('--parallel_id', nargs='+', type=int, default=None)
parser.add_argument('--n_cores', type=int, default=None)
parser.add_argument('--split_seed', type=int, default=0)
parser.add_argument('--results_path', type=str, default=default_dir)

args = parser.parse_args()

assert args.splitting_strategy in {
'train-test', 'train-tune-test', 'train-test-lowdata', 'train-tune-test-lowdata'}
assert args.mode in {'regression', 'binary_classification', 'multiclass_classification'}

if args.parallel:
if args.n_cores is None:
print(os.getenv("SLURM_CPUS_ON_NODE"))
n_cores = int(os.getenv("SLURM_CPUS_ON_NODE"))
else:
n_cores = args.n_cores
client = Client(n_workers=n_cores)

ests, _, Xpath, ypath = fi_config.get_fi_configs(args.config, real_data=True)
metrics = get_metrics(args.mode)

if args.model:
ests = list(filter(lambda x: args.model.lower() == x[0].name.lower(), ests))

if len(ests) == 0:
raise ValueError('No valid estimators', 'sim', args.config, 'models', args.model)
if args.verbose:
print('running', args.config,
'ests', ests)
print('\tsaving to', args.results_path)

results_dir = oj(args.results_path, args.config)
path = oj(results_dir, "seed" + str(args.split_seed))
os.makedirs(path, exist_ok=True)

eval_out = defaultdict(list)

if args.parallel:
futures = [dask.delayed(run_simulation)(i, path, Xpath, ypath, ests, metrics, args) for i in range(args.nreps)]
results = dask.compute(*futures)
else:
results = [run_simulation(i, path, Xpath, ypath, ests, metrics, args) for i in range(args.nreps)]
assert all(results)

print('completed all experiments successfully!')

# get model file names
model_comparison_files_all = []
for est in ests:
estimator_name = est[0].name.split(' - ')[0]
model_comparison_file = f'{estimator_name}_comparisons.pkl'
model_comparison_files_all.append(model_comparison_file)

# aggregate results
y_df = pd.read_csv(ypath)
results_list = []
for col in y_df.columns:
if y_df.shape[1] > 1:
output_path = oj(path, col)
else:
output_path = path
for i in range(args.nreps):
all_files = glob.glob(oj(output_path, 'rep' + str(i), '*'))
model_files = sorted([f for f in all_files if os.path.basename(f) in model_comparison_files_all])

if len(model_files) == 0:
print('No files found at ', oj(output_path, 'rep' + str(i)))
continue

results = pd.concat(
[pkl.load(open(f, 'rb'))['df'] for f in model_files],
axis=0
)
results.insert(0, 'rep', i)
if y_df.shape[1] > 1:
results.insert(1, 'y_task', col)
results_list.append(results)

results_merged = pd.concat(results_list, axis=0)
pkl.dump(results_merged, open(oj(path, 'results.pkl'), 'wb'))
results_df = reformat_results(results_merged)
results_df.to_csv(oj(path, 'results.csv'), index=False)

print('merged and saved all experiment results successfully!')

# %%
Empty file added feature_importance/__init__.py
13 changes: 13 additions & 0 deletions feature_importance/fi_config/__init__.py
@@ -0,0 +1,13 @@
import importlib

def get_fi_configs(config_name, real_data=False):
if real_data:
ests = importlib.import_module(f'fi_config.{config_name}.models')
dgp = importlib.import_module(f'fi_config.{config_name}.dgp')
return ests.ESTIMATORS, ests.FI_ESTIMATORS, dgp.X_PATH, dgp.Y_PATH
else:
ests = importlib.import_module(f'fi_config.{config_name}.models')
dgp = importlib.import_module(f'fi_config.{config_name}.dgp')
return ests.ESTIMATORS, ests.FI_ESTIMATORS, \
dgp.X_DGP, dgp.X_PARAMS_DICT, dgp.Y_DGP, dgp.Y_PARAMS_DICT, \
dgp.VARY_PARAM_NAME, dgp.VARY_PARAM_VALS
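
For orientation, a minimal sketch of how these loaders are consumed (mirroring the call in 04_run_prediction_real_data.py above; the real-data config name is that script's default, while the simulation config name is a hypothetical placeholder):

import fi_config

# real-data configs return the estimator lists plus paths to the X and y CSV files
ests, fi_ests, Xpath, ypath = fi_config.get_fi_configs(
    "mdi_plus.prediction_sims.ccle_rnaseq_regression-", real_data=True)

# simulation configs instead return the data-generating processes, their parameters,
# and the name(s) and values of the parameter(s) varied across simulation settings
ests, fi_ests, X_dgp, X_params, y_dgp, y_params, vary_name, vary_vals = \
    fi_config.get_fi_configs("some_simulation_config")  # hypothetical config name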
@@ -0,0 +1,21 @@
import sys
sys.path.append("../..")
from feature_importance.scripts.simulations_util import *

X_DGP = sample_real_X
X_PARAMS_DICT = {
"fpath": "data/X_ccle_rnaseq_cleaned.csv",
"sample_row_n": None,
"sample_col_n": 1000
}
Y_DGP = logistic_hier_model
Y_PARAMS_DICT = {
"m":3,
"r":2,
"beta": 1,
"frac_label_corruption": None
}

VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
VARY_PARAM_VALS = {"frac_label_corruption": {"0.25": 0.25, "0.15": 0.15, "0.05": 0.05, "0": None},
"sample_row_n": {"100": 100, "250": 250, "472": 472}}
@@ -0,0 +1,20 @@
from sklearn.ensemble import RandomForestClassifier
from feature_importance.util import ModelConfig, FIModelConfig
from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
from imodels.importance.rf_plus import _fast_r2_score
from imodels.importance.ppms import RidgeClassifierPPM


ESTIMATORS = [
[ModelConfig('RF', RandomForestClassifier, model_type='tree',
other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
]

FI_ESTIMATORS = [
[FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
[FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
[FIModelConfig('MDI', tree_mdi, model_type='tree')],
[FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
[FIModelConfig('MDA', tree_mda, model_type='tree')],
[FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
]
