diff --git a/feature_importance/01_run_importance_simulations.py b/feature_importance/01_run_importance_simulations.py
new file mode 100644
index 0000000..874f431
--- /dev/null
+++ b/feature_importance/01_run_importance_simulations.py
@@ -0,0 +1,475 @@
+# Example usage: run from the command line
+# cd feature_importance/
+# python 01_run_importance_simulations.py --nreps 2 --config test --split_seed 12345 --ignore_cache
+# python 01_run_importance_simulations.py --nreps 2 --config test --split_seed 12345 --ignore_cache --create_rmd
+
+import copy
+import os
+from os.path import join as oj
+import glob
+import argparse
+import pickle as pkl
+import time
+import warnings
+from scipy import stats
+import dask
+from dask.distributed import Client
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+import sys
+from collections import defaultdict
+from typing import Callable, List, Tuple
+import itertools
+from sklearn.metrics import roc_auc_score, f1_score, recall_score, precision_score
+
+sys.path.append(".")
+sys.path.append("..")
+sys.path.append("../..")
+import fi_config
+from util import ModelConfig, FIModelConfig, tp, fp, neg, pos, specificity_score, auroc_score, auprc_score, compute_nsg_feat_corr_w_sig_subspace, apply_splitting_strategy
+
+warnings.filterwarnings("ignore", message="Bins whose width")
+
+
+def compare_estimators(estimators: List[ModelConfig],
+ fi_estimators: List[FIModelConfig],
+ X, y, support: List,
+ metrics: List[Tuple[str, Callable]],
+ args, ) -> Tuple[dict, dict]:
+ """Calculates results given estimators, feature importance estimators, datasets, and metrics.
+ Called in run_comparison
+ """
+ if type(estimators) != list:
+ raise Exception("First argument needs to be a list of Models")
+ if type(metrics) != list:
+ raise Exception("Argument metrics needs to be a list containing ('name', callable) pairs")
+
+ # initialize results
+ results = defaultdict(lambda: [])
+
+ # loop over model estimators
+ for model in tqdm(estimators, leave=False):
+ est = model.cls(**model.kwargs)
+
+ # get kwargs for all fi_ests
+ fi_kwargs = {}
+ for fi_est in fi_estimators:
+ fi_kwargs.update(fi_est.kwargs)
+
+ # get groups of estimators for each splitting strategy
+ fi_ests_dict = defaultdict(list)
+ for fi_est in fi_estimators:
+ fi_ests_dict[fi_est.splitting_strategy].append(fi_est)
+
+ # loop over splitting strategies
+ for splitting_strategy, fi_ests in fi_ests_dict.items():
+ # implement provided splitting strategy
+ if splitting_strategy is not None:
+ X_train, X_tune, X_test, y_train, y_tune, y_test = apply_splitting_strategy(X, y, splitting_strategy, args.split_seed)
+ else:
+ X_train = X
+ X_tune = X
+ X_test = X
+ y_train = y
+ y_tune = y
+ y_test = y
+
+ # fit model
+ est.fit(X_train, y_train)
+
+ # compute correlation between signal and nonsignal features
+ x_cor = np.empty(len(support))
+ x_cor[:] = np.NaN
+ x_cor[support == 0] = compute_nsg_feat_corr_w_sig_subspace(X_train[:, support == 1], X_train[:, support == 0])
+
+ # loop over fi estimators
+ for fi_est in fi_ests:
+ metric_results = {
+ 'model': model.name,
+ 'fi': fi_est.name,
+ 'splitting_strategy': splitting_strategy
+ }
+ start = time.time()
+ fi_score = fi_est.cls(X_test, y_test, copy.deepcopy(est), **fi_est.kwargs)
+ end = time.time()
+ support_df = pd.DataFrame({"var": np.arange(len(support)),
+ "true_support": support,
+ "cor_with_signal": x_cor})
+ metric_results['fi_scores'] = pd.merge(copy.deepcopy(fi_score), support_df, on="var", how="left")
+ if np.max(support) != np.min(support):
+ for i, (met_name, met) in enumerate(metrics):
+ if met is not None:
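+                            # clip infinite importances to large finite sentinels and impute NaNs as least
+                            # important (respecting the FI method's sort direction) so ranking metrics stay well-defined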
+ imp_vals = copy.deepcopy(fi_score["importance"])
+ imp_vals[imp_vals == float("-inf")] = -sys.maxsize - 1
+ imp_vals[imp_vals == float("inf")] = sys.maxsize - 1
+ if fi_est.ascending:
+ imp_vals[np.isnan(imp_vals)] = -sys.maxsize - 1
+ metric_results[met_name] = met(support, imp_vals)
+ else:
+ imp_vals[np.isnan(imp_vals)] = sys.maxsize - 1
+ metric_results[met_name] = met(support, -imp_vals)
+ metric_results['time'] = end - start
+
+ # initialize results with metadata and metric results
+ kwargs: dict = model.kwargs # dict
+ for k in kwargs:
+ results[k].append(kwargs[k])
+ for k in fi_kwargs:
+ if k in fi_est.kwargs:
+ results[k].append(str(fi_est.kwargs[k]))
+ else:
+ results[k].append(None)
+ for met_name, met_val in metric_results.items():
+ results[met_name].append(met_val)
+ return results
+
+
+def run_comparison(path: str,
+ X, y, support: List,
+ metrics: List[Tuple[str, Callable]],
+ estimators: List[ModelConfig],
+ fi_estimators: List[FIModelConfig],
+ args):
+ estimator_name = estimators[0].name.split(' - ')[0]
+ fi_estimators_all = [fi_estimator for fi_estimator in itertools.chain(*fi_estimators) \
+ if fi_estimator.model_type in estimators[0].model_type]
+ model_comparison_files_all = [oj(path, f'{estimator_name}_{fi_estimator.name}_comparisons.pkl') \
+ for fi_estimator in fi_estimators_all]
+ if args.parallel_id is not None:
+ model_comparison_files_all = [f'_{args.parallel_id[0]}.'.join(model_comparison_file.split('.')) \
+ for model_comparison_file in model_comparison_files_all]
+
+ fi_estimators = []
+ model_comparison_files = []
+ for model_comparison_file, fi_estimator in zip(model_comparison_files_all, fi_estimators_all):
+ if os.path.isfile(model_comparison_file) and not args.ignore_cache:
+ print(
+ f'{estimator_name} with {fi_estimator.name} results already computed and cached. use --ignore_cache to recompute')
+ else:
+ fi_estimators.append(fi_estimator)
+ model_comparison_files.append(model_comparison_file)
+
+ if len(fi_estimators) == 0:
+ return
+
+ results = compare_estimators(estimators=estimators,
+ fi_estimators=fi_estimators,
+ X=X, y=y, support=support,
+ metrics=metrics,
+ args=args)
+
+ estimators_list = [e.name for e in estimators]
+ metrics_list = [m[0] for m in metrics]
+
+ df = pd.DataFrame.from_dict(results)
+ df['split_seed'] = args.split_seed
+ if args.nosave_cols is not None:
+ nosave_cols = np.unique([x.strip() for x in args.nosave_cols.split(",")])
+ else:
+ nosave_cols = []
+ for col in nosave_cols:
+ if col in df.columns:
+ df = df.drop(columns=[col])
+
+ for model_comparison_file, fi_estimator in zip(model_comparison_files, fi_estimators):
+ output_dict = {
+ # metadata
+ 'sim_name': args.config,
+ 'estimators': estimators_list,
+ 'fi_estimators': fi_estimator.name,
+ 'metrics': metrics_list,
+
+ # actual values
+ 'df': df.loc[df.fi == fi_estimator.name],
+ }
+ pkl.dump(output_dict, open(model_comparison_file, 'wb'))
+ return df
+
+
+def get_metrics():
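+    # ranking metrics for feature importances: the true support serves as labels and the importances as scores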
+ return [('rocauc', auroc_score), ('prauc', auprc_score)]
+
+
+def reformat_results(results):
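+    # expand the nested per-feature 'fi_scores' DataFrames into long format and join each
+    # feature row back to its simulation metadata row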
+ results = results.reset_index().drop(columns=['index'])
+ fi_scores = pd.concat(results.pop('fi_scores').to_dict()). \
+ reset_index(level=0).rename(columns={'level_0': 'index'})
+ results_df = pd.merge(results, fi_scores, left_index=True, right_on="index")
+ return results_df
+
+
+def run_simulation(i, path, val_name, X_params_dict, X_dgp, y_params_dict, y_dgp, ests, fi_ests, metrics, args):
+ os.makedirs(oj(path, val_name, "rep" + str(i)), exist_ok=True)
+ np.random.seed(i)
+ max_iter = 100
+ iter = 0
+ while iter <= max_iter: # regenerate data if y is constant
+ X = X_dgp(**X_params_dict)
+ y, support, beta = y_dgp(X, **y_params_dict, return_support=True)
+ if not all(y == y[0]):
+ break
+ iter += 1
+ if iter > max_iter:
+ raise ValueError("Response y is constant.")
+ if args.omit_vars is not None:
+ omit_vars = np.unique([int(x.strip()) for x in args.omit_vars.split(",")])
+ support = np.delete(support, omit_vars)
+ X = np.delete(X, omit_vars, axis=1)
+ del beta # note: beta is not currently supported when using omit_vars
+
+ for est in ests:
+ results = run_comparison(path=oj(path, val_name, "rep" + str(i)),
+ X=X, y=y, support=support,
+ metrics=metrics,
+ estimators=est,
+ fi_estimators=fi_ests,
+ args=args)
+ return True
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+
+ default_dir = os.getenv("SCRATCH")
+ if default_dir is not None:
+ default_dir = oj(default_dir, "feature_importance", "results")
+ else:
+ default_dir = oj(os.path.dirname(os.path.realpath(__file__)), 'results')
+
+ parser.add_argument('--nreps', type=int, default=2)
+ parser.add_argument('--model', type=str, default=None) # , default='c4')
+ parser.add_argument('--fi_model', type=str, default=None) # , default='c4')
+ parser.add_argument('--config', type=str, default='test')
+ parser.add_argument('--omit_vars', type=str, default=None) # comma-separated string of variables to omit
+ parser.add_argument('--nosave_cols', type=str, default="prediction_model")
+
+ # for multiple reruns, should support varying split_seed
+ parser.add_argument('--ignore_cache', action='store_true', default=False)
+ parser.add_argument('--verbose', action='store_true', default=True)
+ parser.add_argument('--parallel', action='store_true', default=False)
+ parser.add_argument('--parallel_id', nargs='+', type=int, default=None)
+ parser.add_argument('--n_cores', type=int, default=None)
+ parser.add_argument('--split_seed', type=int, default=0)
+ parser.add_argument('--results_path', type=str, default=default_dir)
+
+ # arguments for rmd output of results
+ parser.add_argument('--create_rmd', action='store_true', default=False)
+ parser.add_argument('--show_vars', type=int, default=None)
+
+ args = parser.parse_args()
+
+ if args.parallel:
+ if args.n_cores is None:
+ print(os.getenv("SLURM_CPUS_ON_NODE"))
+ n_cores = int(os.getenv("SLURM_CPUS_ON_NODE"))
+ else:
+ n_cores = args.n_cores
+ client = Client(n_workers=n_cores)
+
+ ests, fi_ests, \
+ X_dgp, X_params_dict, y_dgp, y_params_dict, \
+ vary_param_name, vary_param_vals = fi_config.get_fi_configs(args.config)
+
+ metrics = get_metrics()
+
+ if args.model:
+ ests = list(filter(lambda x: args.model.lower() == x[0].name.lower(), ests))
+ if args.fi_model:
+ fi_ests = list(filter(lambda x: args.fi_model.lower() == x[0].name.lower(), fi_ests))
+
+ if len(ests) == 0:
+ raise ValueError('No valid estimators', 'sim', args.config, 'models', args.model, 'fi', args.fi_model)
+ if len(fi_ests) == 0:
+ raise ValueError('No valid FI estimators', 'sim', args.config, 'models', args.model, 'fi', args.fi_model)
+ if args.verbose:
+ print('running', args.config,
+ 'ests', ests,
+ 'fi_ests', fi_ests)
+ print('\tsaving to', args.results_path)
+
+ if args.omit_vars is not None:
+ results_dir = oj(args.results_path, args.config + "_omitted_vars")
+ else:
+ results_dir = oj(args.results_path, args.config)
+
+ if isinstance(vary_param_name, list):
+ path = oj(results_dir, "varying_" + "_".join(vary_param_name), "seed" + str(args.split_seed))
+ else:
+ path = oj(results_dir, "varying_" + vary_param_name, "seed" + str(args.split_seed))
+ os.makedirs(path, exist_ok=True)
+
+ eval_out = defaultdict(list)
+
+ vary_type = None
+ if isinstance(vary_param_name, list): # multiple parameters are being varied
+ # get parameters that are being varied over and identify whether it's a DGP/method/fi_method argument
+ keys, values = zip(*vary_param_vals.items())
+ vary_param_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]
+ vary_type = {}
+ for vary_param_dict in vary_param_dicts:
+ for param_name, param_val in vary_param_dict.items():
+ if param_name in X_params_dict.keys() and param_name in y_params_dict.keys():
+ raise ValueError('Cannot vary over parameter in both X and y DGPs.')
+ elif param_name in X_params_dict.keys():
+ vary_type[param_name] = "dgp"
+ X_params_dict[param_name] = vary_param_vals[param_name][param_val]
+ elif param_name in y_params_dict.keys():
+ vary_type[param_name] = "dgp"
+ y_params_dict[param_name] = vary_param_vals[param_name][param_val]
+ else:
+ est_kwargs = list(
+ itertools.chain(*[list(est.kwargs.keys()) for est in list(itertools.chain(*ests))]))
+ fi_est_kwargs = list(
+ itertools.chain(*[list(fi_est.kwargs.keys()) for fi_est in list(itertools.chain(*fi_ests))]))
+ if param_name in est_kwargs:
+ vary_type[param_name] = "est"
+ elif param_name in fi_est_kwargs:
+ vary_type[param_name] = "fi_est"
+ else:
+ raise ValueError('Invalid vary_param_name.')
+
+ if args.parallel:
+ futures = [
+ dask.delayed(run_simulation)(i, path, "_".join(vary_param_dict.values()), X_params_dict, X_dgp,
+ y_params_dict, y_dgp, ests, fi_ests, metrics, args) for i in
+ range(args.nreps)]
+ results = dask.compute(*futures)
+ else:
+ results = [
+ run_simulation(i, path, "_".join(vary_param_dict.values()), X_params_dict, X_dgp, y_params_dict,
+ y_dgp, ests, fi_ests, metrics, args) for i in range(args.nreps)]
+ assert all(results)
+
+    else:  # only one parameter is being varied over
+ # get parameter that is being varied over and identify whether it's a DGP/method/fi_method argument
+ for val_name, val in vary_param_vals.items():
+ if vary_param_name in X_params_dict.keys() and vary_param_name in y_params_dict.keys():
+ raise ValueError('Cannot vary over parameter in both X and y DGPs.')
+ elif vary_param_name in X_params_dict.keys():
+ vary_type = "dgp"
+ X_params_dict[vary_param_name] = val
+ elif vary_param_name in y_params_dict.keys():
+ vary_type = "dgp"
+ y_params_dict[vary_param_name] = val
+ else:
+ est_kwargs = list(itertools.chain(*[list(est.kwargs.keys()) for est in list(itertools.chain(*ests))]))
+ fi_est_kwargs = list(
+ itertools.chain(*[list(fi_est.kwargs.keys()) for fi_est in list(itertools.chain(*fi_ests))]))
+ if vary_param_name in est_kwargs:
+ vary_type = "est"
+ elif vary_param_name in fi_est_kwargs:
+ vary_type = "fi_est"
+ else:
+ raise ValueError('Invalid vary_param_name.')
+
+ if args.parallel:
+ futures = [
+ dask.delayed(run_simulation)(i, path, val_name, X_params_dict, X_dgp, y_params_dict, y_dgp, ests,
+ fi_ests, metrics, args) for i in range(args.nreps)]
+ results = dask.compute(*futures)
+ else:
+ results = [run_simulation(i, path, val_name, X_params_dict, X_dgp, y_params_dict, y_dgp, ests, fi_ests,
+ metrics, args) for i in range(args.nreps)]
+ assert all(results)
+
+ print('completed all experiments successfully!')
+
+ # get model file names
+ model_comparison_files_all = []
+ for est in ests:
+ estimator_name = est[0].name.split(' - ')[0]
+ fi_estimators_all = [fi_estimator for fi_estimator in itertools.chain(*fi_ests) \
+ if fi_estimator.model_type in est[0].model_type]
+ model_comparison_files = [f'{estimator_name}_{fi_estimator.name}_comparisons.pkl' for fi_estimator in
+ fi_estimators_all]
+ model_comparison_files_all += model_comparison_files
+
+ # aggregate results
+ results_list = []
+ if isinstance(vary_param_name, list):
+ for vary_param_dict in vary_param_dicts:
+ val_name = "_".join(vary_param_dict.values())
+
+ for i in range(args.nreps):
+ all_files = glob.glob(oj(path, val_name, 'rep' + str(i), '*'))
+ model_files = sorted([f for f in all_files if os.path.basename(f) in model_comparison_files_all])
+
+ if len(model_files) == 0:
+ print('No files found at ', oj(path, val_name, 'rep' + str(i)))
+ continue
+
+ results = pd.concat(
+ [pkl.load(open(f, 'rb'))['df'] for f in model_files],
+ axis=0
+ )
+
+ for param_name, param_val in vary_param_dict.items():
+ val = vary_param_vals[param_name][param_val]
+ if vary_type[param_name] == "dgp":
+ if np.isscalar(val):
+ results.insert(0, param_name, val)
+ else:
+ results.insert(0, param_name, [val for i in range(results.shape[0])])
+ results.insert(1, param_name + "_name", param_val)
+ elif vary_type[param_name] == "est" or vary_type[param_name] == "fi_est":
+ results.insert(0, param_name + "_name", copy.deepcopy(results[param_name]))
+ results.insert(0, 'rep', i)
+ results_list.append(results)
+ else:
+ for val_name, val in vary_param_vals.items():
+ for i in range(args.nreps):
+ all_files = glob.glob(oj(path, val_name, 'rep' + str(i), '*'))
+ model_files = sorted([f for f in all_files if os.path.basename(f) in model_comparison_files_all])
+
+ if len(model_files) == 0:
+ print('No files found at ', oj(path, val_name, 'rep' + str(i)))
+ continue
+
+ results = pd.concat(
+ [pkl.load(open(f, 'rb'))['df'] for f in model_files],
+ axis=0
+ )
+ if vary_type == "dgp":
+ if np.isscalar(val):
+ results.insert(0, vary_param_name, val)
+ else:
+ results.insert(0, vary_param_name, [val for i in range(results.shape[0])])
+ results.insert(1, vary_param_name + "_name", val_name)
+ results.insert(2, 'rep', i)
+ elif vary_type == "est" or vary_type == "fi_est":
+ results.insert(0, vary_param_name + "_name", copy.deepcopy(results[vary_param_name]))
+ results.insert(1, 'rep', i)
+ results_list.append(results)
+ results_merged = pd.concat(results_list, axis=0)
+ pkl.dump(results_merged, open(oj(path, 'results.pkl'), 'wb'))
+ results_df = reformat_results(results_merged)
+ results_df.to_csv(oj(path, 'results.csv'), index=False)
+
+ print('merged and saved all experiment results successfully!')
+
+ # create R markdown summary of results
+ if args.create_rmd:
+ if args.show_vars is None:
+ show_vars = 'NULL'
+ else:
+ show_vars = args.show_vars
+
+ if isinstance(vary_param_name, list):
+ vary_param_name = "; ".join(vary_param_name)
+
+ sim_rmd = os.path.basename(results_dir) + '_simulation_results.Rmd'
+ os.system(
+ 'cp {} \'{}\''.format(oj("rmd", "simulation_results.Rmd"), sim_rmd)
+ )
+ os.system(
+ 'Rscript -e "rmarkdown::render(\'{}\', params = list(results_dir = \'{}\', vary_param_name = \'{}\', seed = {}, keep_vars = {}), output_file = \'{}\', quiet = TRUE)"'.format(
+ sim_rmd,
+ results_dir, vary_param_name, str(args.split_seed), str(show_vars),
+ oj(path, "simulation_results.html"))
+ )
+ os.system('rm \'{}\''.format(sim_rmd))
+ print("created rmd of simulation results successfully!")
+
+# %%
diff --git a/feature_importance/02_run_importance_real_data.py b/feature_importance/02_run_importance_real_data.py
new file mode 100644
index 0000000..4a51745
--- /dev/null
+++ b/feature_importance/02_run_importance_real_data.py
@@ -0,0 +1,402 @@
+# Example usage: run from the command line
+# cd feature_importance/
+# python 02_run_importance_real_data.py --nreps 2 --config test_real_data --split_seed 12345 --ignore_cache
+
+import copy
+import os
+from os.path import join as oj
+import glob
+import argparse
+import pickle as pkl
+import time
+import warnings
+from scipy import stats
+import dask
+from dask.distributed import Client
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+import sys
+from collections import defaultdict
+from typing import Callable, List, Tuple
+import itertools
+from functools import partial
+
+sys.path.append(".")
+sys.path.append("..")
+sys.path.append("../..")
+import fi_config
+from util import ModelConfig, FIModelConfig, apply_splitting_strategy, auroc_score, auprc_score
+
+from sklearn.metrics import accuracy_score, f1_score, recall_score, \
+ precision_score, average_precision_score, r2_score, explained_variance_score, \
+ mean_squared_error, mean_absolute_error, log_loss
+
+warnings.filterwarnings("ignore", message="Bins whose width")
+
+
+def compare_estimators(estimators: List[ModelConfig],
+ fi_estimators: List[FIModelConfig],
+ X, y, args, ) -> Tuple[dict, dict]:
+ """Calculates results given estimators, feature importance estimators, and datasets.
+ Called in run_comparison
+ """
+ if type(estimators) != list:
+ raise Exception("First argument needs to be a list of Models")
+
+ # initialize results
+ results = defaultdict(lambda: [])
+
+ # loop over model estimators
+ for model in tqdm(estimators, leave=False):
+ est = model.cls(**model.kwargs)
+
+ # get kwargs for all fi_ests
+ fi_kwargs = {}
+ for fi_est in fi_estimators:
+ fi_kwargs.update(fi_est.kwargs)
+
+ # get groups of estimators for each splitting strategy
+ fi_ests_dict = defaultdict(list)
+ for fi_est in fi_estimators:
+ fi_ests_dict[fi_est.splitting_strategy].append(fi_est)
+
+ # loop over splitting strategies
+ for splitting_strategy, fi_ests in fi_ests_dict.items():
+ # implement provided splitting strategy
+ if splitting_strategy is not None:
+ X_train, X_tune, X_test, y_train, y_tune, y_test = apply_splitting_strategy(X, y, splitting_strategy, args.split_seed)
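+                # under "train-test-prediction", hold out the test split for downstream prediction
+                # evaluation and compute feature importances on the training data instead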
+ if splitting_strategy == "train-test-prediction":
+ X_test_pred = copy.deepcopy(X_test)
+ y_test_pred = copy.deepcopy(y_test)
+ X_test = X_train
+ y_test = y_train
+ else:
+ X_train = X
+ X_tune = X
+ X_test = X
+ y_train = y
+ y_tune = y
+ y_test = y
+
+ # fit model
+ est.fit(X_train, y_train)
+
+ # loop over fi estimators
+ for fi_est in fi_ests:
+ metric_results = {
+ 'model': model.name,
+ 'fi': fi_est.name,
+ 'splitting_strategy': splitting_strategy
+ }
+ start = time.time()
+ fi_score = fi_est.cls(X_test, y_test, copy.deepcopy(est), **fi_est.kwargs)
+ end = time.time()
+ metric_results['fi_scores'] = copy.deepcopy(fi_score)
+ metric_results['time'] = end - start
+
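+                # refit the prediction model on only the top-k ranked features and score it on the
+                # held-out prediction split, for each k in eval_top_ks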
+ if splitting_strategy == "train-test-prediction" and args.eval_top_ks is not None:
+ fi_rankings = fi_score.sort_values("importance", ascending=not fi_est.ascending)
+ pred_est = copy.deepcopy(est)
+ metrics = get_metrics(args.mode)
+ if args.eval_top_ks == "auto":
+ eval_top_ks = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 15,
+ 20, 25, 30, 35, 40, 45, 50, 60, 70, 80, 90, 100]
+ else:
+ eval_top_ks = [int(k.strip()) for k in args.eval_top_ks.split(",")]
+ pred_metrics_out = []
+ for k in eval_top_ks:
+ top_k_features = fi_rankings["var"][:k]
+ Xk_train = X_train[:, top_k_features]
+ Xk_test_pred = X_test_pred[:, top_k_features]
+ pred_est.fit(Xk_train, y_train)
+ y_pred = pred_est.predict(Xk_test_pred)
+ if args.mode != 'regression':
+ y_pred_proba = pred_est.predict_proba(Xk_test_pred)
+ if args.mode == 'binary_classification':
+ y_pred_proba = y_pred_proba[:, 1]
+ else:
+ y_pred_proba = y_pred
+ for met_name, met in metrics:
+ if met is not None:
+ if args.mode == 'regression' \
+ or met_name in ['accuracy', 'f1', 'precision', 'recall']:
+ pred_metrics_out.append({
+ "k": k, "metric": met_name, "metric_value": met(y_test_pred, y_pred)
+ })
+ else:
+ pred_metrics_out.append({
+ "k": k, "metric": met_name, "metric_value": met(y_test_pred, y_pred_proba)
+ })
+ metric_results["pred_metrics"] = copy.deepcopy(pd.DataFrame(pred_metrics_out))
+
+ # initialize results with metadata and results
+ kwargs: dict = model.kwargs # dict
+ for k in kwargs:
+ results[k].append(kwargs[k])
+ for k in fi_kwargs:
+ if k in fi_est.kwargs:
+ results[k].append(str(fi_est.kwargs[k]))
+ else:
+ results[k].append(None)
+ for met_name, met_val in metric_results.items():
+ results[met_name].append(met_val)
+ return results
+
+
+def run_comparison(path: str,
+ X, y,
+ estimators: List[ModelConfig],
+ fi_estimators: List[FIModelConfig],
+ args):
+ estimator_name = estimators[0].name.split(' - ')[0]
+ fi_estimators_all = [fi_estimator for fi_estimator in itertools.chain(*fi_estimators) \
+ if fi_estimator.model_type in estimators[0].model_type]
+ model_comparison_files_all = [oj(path, f'{estimator_name}_{fi_estimator.name}_comparisons.pkl') \
+ for fi_estimator in fi_estimators_all]
+ if args.parallel_id is not None:
+ model_comparison_files_all = [f'_{args.parallel_id[0]}.'.join(model_comparison_file.split('.')) \
+ for model_comparison_file in model_comparison_files_all]
+
+ fi_estimators = []
+ model_comparison_files = []
+ for model_comparison_file, fi_estimator in zip(model_comparison_files_all, fi_estimators_all):
+ if os.path.isfile(model_comparison_file) and not args.ignore_cache:
+ print(
+ f'{estimator_name} with {fi_estimator.name} results already computed and cached. use --ignore_cache to recompute')
+ else:
+ fi_estimators.append(fi_estimator)
+ model_comparison_files.append(model_comparison_file)
+
+ if len(fi_estimators) == 0:
+ return
+
+ results = compare_estimators(estimators=estimators,
+ fi_estimators=fi_estimators,
+ X=X, y=y,
+ args=args)
+
+ estimators_list = [e.name for e in estimators]
+
+ df = pd.DataFrame.from_dict(results)
+ df['split_seed'] = args.split_seed
+ if args.nosave_cols is not None:
+ nosave_cols = np.unique([x.strip() for x in args.nosave_cols.split(",")])
+ else:
+ nosave_cols = []
+ for col in nosave_cols:
+ if col in df.columns:
+ df = df.drop(columns=[col])
+
+ for model_comparison_file, fi_estimator in zip(model_comparison_files, fi_estimators):
+ output_dict = {
+ # metadata
+ 'sim_name': args.config,
+ 'estimators': estimators_list,
+ 'fi_estimators': fi_estimator.name,
+
+ # actual values
+ 'df': df.loc[df.fi == fi_estimator.name],
+ }
+ pkl.dump(output_dict, open(model_comparison_file, 'wb'))
+ return df
+
+
+def reformat_results(results):
+ results = results.reset_index().drop(columns=['index'])
+ fi_scores = pd.concat(results.pop('fi_scores').to_dict()). \
+ reset_index(level=0).rename(columns={'level_0': 'index'})
+ if "pred_metrics" in results.columns:
+ pred_metrics = pd.concat(results.pop('pred_metrics').to_dict()). \
+ reset_index(level=0).rename(columns={'level_0': 'index'})
+ else:
+ pred_metrics = None
+ results_df = pd.merge(results, fi_scores, left_index=True, right_on="index")
+ if pred_metrics is not None:
+ pred_results_df = pd.merge(results, pred_metrics, left_index=True, right_on="index")
+ else:
+ pred_results_df = None
+ return results_df, pred_results_df
+
+
+def get_metrics(mode: str = 'regression'):
+ if mode == 'binary_classification':
+ return [
+ ('rocauc', auroc_score),
+ ('prauc', auprc_score),
+ ('logloss', log_loss),
+ ('accuracy', accuracy_score),
+ ('f1', f1_score),
+ ('recall', recall_score),
+ ('precision', precision_score),
+ ('avg_precision', average_precision_score)
+ ]
+ elif mode == 'multiclass_classification':
+ return [
+ ('rocauc', partial(auroc_score, multi_class="ovr")),
+ ('prauc', partial(auprc_score, multi_class="ovr")),
+ ('logloss', log_loss),
+ ('accuracy', accuracy_score),
+ ('f1', partial(f1_score, average='micro')),
+ ('recall', partial(recall_score, average='micro')),
+ ('precision', partial(precision_score, average='micro'))
+ ]
+ elif mode == 'regression':
+ return [
+ ('r2', r2_score),
+ ('explained_variance', explained_variance_score),
+ ('mean_squared_error', mean_squared_error),
+ ('mean_absolute_error', mean_absolute_error),
+ ]
+
+
+def run_simulation(i, path, Xpath, ypath, ests, fi_ests, args):
+ X_df = pd.read_csv(Xpath)
+ y_df = pd.read_csv(ypath)
+ if args.response_idx is None:
+ keep_cols = y_df.columns
+ else:
+ keep_cols = [args.response_idx]
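+    # loop over each response column, dropping samples with a missing response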
+ for col in keep_cols:
+ y = y_df[col].to_numpy().ravel()
+ keep_idx = ~pd.isnull(y)
+ X = X_df[keep_idx].to_numpy()
+ y = y[keep_idx]
+ if y_df.shape[1] > 1:
+ output_path = oj(path, col)
+ else:
+ output_path = path
+ os.makedirs(oj(output_path, "rep" + str(i)), exist_ok=True)
+ for est in ests:
+ for idx in range(len(est)):
+ if "random_state" in est[idx].kwargs.keys():
+ est[idx].kwargs["random_state"] = i
+ results = run_comparison(
+ path=oj(output_path, "rep" + str(i)),
+ X=X, y=y,
+ estimators=est,
+ fi_estimators=fi_ests,
+ args=args
+ )
+
+ return True
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+
+ default_dir = os.getenv("SCRATCH")
+ if default_dir is not None:
+ default_dir = oj(default_dir, "feature_importance", "results")
+ else:
+ default_dir = oj(os.path.dirname(os.path.realpath(__file__)), 'results')
+
+ parser.add_argument('--nreps', type=int, default=1)
+ parser.add_argument('--mode', type=str, default='binary_classification')
+ parser.add_argument('--model', type=str, default=None)
+ parser.add_argument('--fi_model', type=str, default=None)
+ parser.add_argument('--config', type=str, default='test_real_data')
+ parser.add_argument('--response_idx', type=str, default=None)
+ parser.add_argument('--nosave_cols', type=str, default=None)
+ parser.add_argument('--eval_top_ks', type=str, default="auto")
+
+ # for multiple reruns, should support varying split_seed
+ parser.add_argument('--ignore_cache', action='store_true', default=False)
+ parser.add_argument('--verbose', action='store_true', default=True)
+ parser.add_argument('--parallel', action='store_true', default=False)
+ parser.add_argument('--parallel_id', nargs='+', type=int, default=None)
+ parser.add_argument('--n_cores', type=int, default=None)
+ parser.add_argument('--split_seed', type=int, default=0)
+ parser.add_argument('--results_path', type=str, default=default_dir)
+
+ args = parser.parse_args()
+
+ assert args.mode in {'regression', 'binary_classification', 'multiclass_classification'}
+
+ if args.parallel:
+ if args.n_cores is None:
+ print(os.getenv("SLURM_CPUS_ON_NODE"))
+ n_cores = int(os.getenv("SLURM_CPUS_ON_NODE"))
+ else:
+ n_cores = args.n_cores
+ client = Client(n_workers=n_cores)
+
+ ests, fi_ests, Xpath, ypath = fi_config.get_fi_configs(args.config, real_data=True)
+
+ if args.model:
+ ests = list(filter(lambda x: args.model.lower() == x[0].name.lower(), ests))
+ if args.fi_model:
+ fi_ests = list(filter(lambda x: args.fi_model.lower() == x[0].name.lower(), fi_ests))
+
+ if len(ests) == 0:
+ raise ValueError('No valid estimators', 'sim', args.config, 'models', args.model, 'fi', args.fi_model)
+ if len(fi_ests) == 0:
+ raise ValueError('No valid FI estimators', 'sim', args.config, 'models', args.model, 'fi', args.fi_model)
+ if args.verbose:
+ print('running', args.config,
+ 'ests', ests,
+ 'fi_ests', fi_ests)
+ print('\tsaving to', args.results_path)
+
+ results_dir = oj(args.results_path, args.config)
+ path = oj(results_dir, "seed" + str(args.split_seed))
+ os.makedirs(path, exist_ok=True)
+
+ eval_out = defaultdict(list)
+
+ if args.parallel:
+ futures = [dask.delayed(run_simulation)(i, path, Xpath, ypath, ests, fi_ests, args) for i in range(args.nreps)]
+ results = dask.compute(*futures)
+ else:
+ results = [run_simulation(i, path, Xpath, ypath, ests, fi_ests, args) for i in range(args.nreps)]
+ assert all(results)
+
+ print('completed all experiments successfully!')
+
+ # get model file names
+ model_comparison_files_all = []
+ for est in ests:
+ estimator_name = est[0].name.split(' - ')[0]
+ fi_estimators_all = [fi_estimator for fi_estimator in itertools.chain(*fi_ests) \
+ if fi_estimator.model_type in est[0].model_type]
+ model_comparison_files = [f'{estimator_name}_{fi_estimator.name}_comparisons.pkl' for fi_estimator in
+ fi_estimators_all]
+ model_comparison_files_all += model_comparison_files
+
+ # aggregate results
+ y_df = pd.read_csv(ypath)
+ results_list = []
+ for col in y_df.columns:
+ if y_df.shape[1] > 1:
+ output_path = oj(path, col)
+ else:
+ output_path = path
+ for i in range(args.nreps):
+ all_files = glob.glob(oj(output_path, 'rep' + str(i), '*'))
+ model_files = sorted([f for f in all_files if os.path.basename(f) in model_comparison_files_all])
+
+ if len(model_files) == 0:
+ print('No files found at ', oj(output_path, 'rep' + str(i)))
+ continue
+
+ results = pd.concat(
+ [pkl.load(open(f, 'rb'))['df'] for f in model_files],
+ axis=0
+ )
+ results.insert(0, 'rep', i)
+ if y_df.shape[1] > 1:
+ results.insert(1, 'y_task', col)
+ results_list.append(results)
+
+ results_merged = pd.concat(results_list, axis=0)
+ pkl.dump(results_merged, open(oj(path, 'results.pkl'), 'wb'))
+ results_df, pred_results_df = reformat_results(results_merged)
+ results_df.to_csv(oj(path, 'results.csv'), index=False)
+ if pred_results_df is not None:
+ pred_results_df.to_csv(oj(path, 'pred_results.csv'), index=False)
+
+ print('merged and saved all experiment results successfully!')
+
+# %%
diff --git a/feature_importance/03_run_prediction_simulations.py b/feature_importance/03_run_prediction_simulations.py
new file mode 100644
index 0000000..0cddec9
--- /dev/null
+++ b/feature_importance/03_run_prediction_simulations.py
@@ -0,0 +1,422 @@
+# Example usage: run from the command line
+# cd feature_importance/
+# python 03_run_prediction_simulations.py --nreps 2 --config test --split_seed 12345 --ignore_cache
+
+import copy
+import os
+from os.path import join as oj
+import glob
+import argparse
+import pickle as pkl
+import time
+import warnings
+from scipy import stats
+import dask
+from dask.distributed import Client
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+import sys
+from collections import defaultdict
+from typing import Callable, List, Tuple
+import itertools
+from functools import partial
+
+sys.path.append(".")
+sys.path.append("..")
+sys.path.append("../..")
+import fi_config
+from util import ModelConfig, apply_splitting_strategy, auroc_score, auprc_score
+
+from sklearn.metrics import accuracy_score, f1_score, recall_score, \
+ precision_score, average_precision_score, r2_score, explained_variance_score, \
+ mean_squared_error, mean_absolute_error, log_loss
+
+warnings.filterwarnings("ignore", message="Bins whose width")
+
+
+def compare_estimators(estimators: List[ModelConfig],
+ X, y,
+ metrics: List[Tuple[str, Callable]],
+ args, rep) -> Tuple[dict, dict]:
+ """Calculates results given estimators, feature importance estimators, and datasets.
+ Called in run_comparison
+ """
+ if type(estimators) != list:
+ raise Exception("First argument needs to be a list of Models")
+ if type(metrics) != list:
+ raise Exception("Argument metrics needs to be a list containing ('name', callable) pairs")
+
+ # initialize results
+ results = defaultdict(lambda: [])
+
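+    # each replicate uses its own train/test split (split seed offset by the replicate index)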
+ if args.splitting_strategy is not None:
+ X_train, X_tune, X_test, y_train, y_tune, y_test = apply_splitting_strategy(
+ X, y, args.splitting_strategy, args.split_seed + rep)
+ else:
+ X_train = X
+ X_tune = X
+ X_test = X
+ y_train = y
+ y_tune = y
+ y_test = y
+
+ # loop over model estimators
+ for model in tqdm(estimators, leave=False):
+ est = model.cls(**model.kwargs)
+
+ start = time.time()
+ est.fit(X_train, y_train)
+ end = time.time()
+
+ metric_results = {'model': model.name}
+ y_pred = est.predict(X_test)
+ if args.mode != 'regression':
+ y_pred_proba = est.predict_proba(X_test)
+ if args.mode == 'binary_classification':
+ y_pred_proba = y_pred_proba[:, 1]
+ else:
+ y_pred_proba = y_pred
+ for met_name, met in metrics:
+ if met is not None:
+ if args.mode == 'regression' \
+ or met_name in ['accuracy', 'f1', 'precision', 'recall']:
+ metric_results[met_name] = met(y_test, y_pred)
+ else:
+ metric_results[met_name] = met(y_test, y_pred_proba)
+ metric_results['predictions'] = copy.deepcopy(pd.DataFrame(y_pred_proba))
+ metric_results['time'] = end - start
+
+ # initialize results with metadata and metric results
+ kwargs: dict = model.kwargs # dict
+ for k in kwargs:
+ results[k].append(kwargs[k])
+ for met_name, met_val in metric_results.items():
+ results[met_name].append(met_val)
+
+ return results
+
+
+def run_comparison(rep: int,
+ path: str,
+ X, y,
+ metrics: List[Tuple[str, Callable]],
+ estimators: List[ModelConfig],
+ args):
+ estimator_name = estimators[0].name.split(' - ')[0]
+ model_comparison_file = oj(path, f'{estimator_name}_comparisons.pkl')
+ if args.parallel_id is not None:
+ model_comparison_file = f'_{args.parallel_id[0]}.'.join(model_comparison_file.split('.'))
+
+ if os.path.isfile(model_comparison_file) and not args.ignore_cache:
+ print(f'{estimator_name} results already computed and cached. use --ignore_cache to recompute')
+ return
+
+ results = compare_estimators(estimators=estimators,
+ X=X, y=y,
+ metrics=metrics,
+ args=args,
+ rep=rep)
+
+ estimators_list = [e.name for e in estimators]
+
+ df = pd.DataFrame.from_dict(results)
+ df['split_seed'] = args.split_seed + rep
+ if args.nosave_cols is not None:
+ nosave_cols = np.unique([x.strip() for x in args.nosave_cols.split(",")])
+ else:
+ nosave_cols = []
+ for col in nosave_cols:
+ if col in df.columns:
+ df = df.drop(columns=[col])
+
+ output_dict = {
+ # metadata
+ 'sim_name': args.config,
+ 'estimators': estimators_list,
+
+ # actual values
+ 'df': df,
+ }
+ pkl.dump(output_dict, open(model_comparison_file, 'wb'))
+
+ return df
+
+
+def reformat_results(results):
+ results = results.reset_index().drop(columns=['index'])
+ predictions = pd.concat(results.pop('predictions').to_dict()). \
+ reset_index(level=[0, 1]).rename(columns={'level_0': 'index', 'level_1': 'sample_id'})
+ results_df = pd.merge(results, predictions, left_index=True, right_on="index")
+ return results_df
+
+
+def get_metrics(mode: str = 'regression'):
+ if mode == 'binary_classification':
+ return [
+ ('rocauc', auroc_score),
+ ('prauc', auprc_score),
+ ('logloss', log_loss),
+ ('accuracy', accuracy_score),
+ ('f1', f1_score),
+ ('recall', recall_score),
+ ('precision', precision_score),
+ ('avg_precision', average_precision_score)
+ ]
+ elif mode == 'multiclass_classification':
+ return [
+ ('rocauc', partial(auroc_score, multi_class="ovr")),
+ ('prauc', partial(auprc_score, multi_class="ovr")),
+ ('logloss', log_loss),
+ ('accuracy', accuracy_score),
+ ('f1', partial(f1_score, average='micro')),
+ ('recall', partial(recall_score, average='micro')),
+ ('precision', partial(precision_score, average='micro'))
+ ]
+ elif mode == 'regression':
+ return [
+ ('r2', r2_score),
+ ('explained_variance', explained_variance_score),
+ ('mean_squared_error', mean_squared_error),
+ ('mean_absolute_error', mean_absolute_error),
+ ]
+
+
+def run_simulation(i, path, val_name, X_params_dict, X_dgp, y_params_dict, y_dgp, ests, metrics, args):
+ os.makedirs(oj(path, val_name, "rep" + str(i)), exist_ok=True)
+ np.random.seed(i)
+ max_iter = 100
+ iter = 0
+ while iter <= max_iter: # regenerate data if y is constant
+ X = X_dgp(**X_params_dict)
+ y, support, beta = y_dgp(X, **y_params_dict, return_support=True)
+ if not all(y == y[0]):
+ break
+ iter += 1
+ if iter > max_iter:
+ raise ValueError("Response y is constant.")
+ if args.omit_vars is not None:
+ omit_vars = np.unique([int(x.strip()) for x in args.omit_vars.split(",")])
+ support = np.delete(support, omit_vars)
+ X = np.delete(X, omit_vars, axis=1)
+ del beta # note: beta is not currently supported when using omit_vars
+
+ for est in ests:
+ results = run_comparison(rep=i,
+ path=oj(path, val_name, "rep" + str(i)),
+ X=X, y=y,
+ metrics=metrics,
+ estimators=est,
+ args=args)
+ return True
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+
+ default_dir = os.getenv("SCRATCH")
+ if default_dir is not None:
+ default_dir = oj(default_dir, "feature_importance", "results")
+ else:
+ default_dir = oj(os.path.dirname(os.path.realpath(__file__)), 'results')
+
+ parser.add_argument('--nreps', type=int, default=2)
+ parser.add_argument('--mode', type=str, default='regression')
+ parser.add_argument('--model', type=str, default=None)
+ parser.add_argument('--config', type=str, default='mdi_plus.prediction_sims.ccle_rnaseq_regression-')
+ parser.add_argument('--omit_vars', type=str, default=None) # comma-separated string of variables to omit
+ parser.add_argument('--nosave_cols', type=str, default=None)
+
+ # for multiple reruns, should support varying split_seed
+ parser.add_argument('--ignore_cache', action='store_true', default=False)
+ parser.add_argument('--splitting_strategy', type=str, default="train-test")
+ parser.add_argument('--verbose', action='store_true', default=True)
+ parser.add_argument('--parallel', action='store_true', default=False)
+ parser.add_argument('--parallel_id', nargs='+', type=int, default=None)
+ parser.add_argument('--n_cores', type=int, default=None)
+ parser.add_argument('--split_seed', type=int, default=0)
+ parser.add_argument('--results_path', type=str, default=default_dir)
+
+ args = parser.parse_args()
+
+ assert args.splitting_strategy in {
+ 'train-test', 'train-tune-test', 'train-test-lowdata', 'train-tune-test-lowdata'}
+ assert args.mode in {'regression', 'binary_classification', 'multiclass_classification'}
+
+ if args.parallel:
+ if args.n_cores is None:
+ print(os.getenv("SLURM_CPUS_ON_NODE"))
+ n_cores = int(os.getenv("SLURM_CPUS_ON_NODE"))
+ else:
+ n_cores = args.n_cores
+ client = Client(n_workers=n_cores)
+
+ ests, fi_ests, \
+ X_dgp, X_params_dict, y_dgp, y_params_dict, \
+ vary_param_name, vary_param_vals = fi_config.get_fi_configs(args.config)
+ metrics = get_metrics(args.mode)
+
+ if args.model:
+ ests = list(filter(lambda x: args.model.lower() == x[0].name.lower(), ests))
+
+ if len(ests) == 0:
+ raise ValueError('No valid estimators', 'sim', args.config, 'models', args.model)
+ if args.verbose:
+ print('running', args.config,
+ 'ests', ests)
+ print('\tsaving to', args.results_path)
+
+ if args.omit_vars is not None:
+ results_dir = oj(args.results_path, args.config + "_omitted_vars")
+ else:
+ results_dir = oj(args.results_path, args.config)
+
+ if isinstance(vary_param_name, list):
+ path = oj(results_dir, "varying_" + "_".join(vary_param_name), "seed" + str(args.split_seed))
+ else:
+ path = oj(results_dir, "varying_" + vary_param_name, "seed" + str(args.split_seed))
+ os.makedirs(path, exist_ok=True)
+
+ eval_out = defaultdict(list)
+
+ vary_type = None
+ if isinstance(vary_param_name, list): # multiple parameters are being varied
+ # get parameters that are being varied over and identify whether it's a DGP/method/fi_method argument
+ keys, values = zip(*vary_param_vals.items())
+ vary_param_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]
+ vary_type = {}
+ for vary_param_dict in vary_param_dicts:
+ for param_name, param_val in vary_param_dict.items():
+ if param_name in X_params_dict.keys() and param_name in y_params_dict.keys():
+ raise ValueError('Cannot vary over parameter in both X and y DGPs.')
+ elif param_name in X_params_dict.keys():
+ vary_type[param_name] = "dgp"
+ X_params_dict[param_name] = vary_param_vals[param_name][param_val]
+ elif param_name in y_params_dict.keys():
+ vary_type[param_name] = "dgp"
+ y_params_dict[param_name] = vary_param_vals[param_name][param_val]
+ else:
+ est_kwargs = list(
+ itertools.chain(*[list(est.kwargs.keys()) for est in list(itertools.chain(*ests))]))
+ if param_name in est_kwargs:
+ vary_type[param_name] = "est"
+ else:
+ raise ValueError('Invalid vary_param_name.')
+
+ if args.parallel:
+ futures = [
+ dask.delayed(run_simulation)(i, path, "_".join(vary_param_dict.values()), X_params_dict, X_dgp,
+ y_params_dict, y_dgp, ests, metrics, args) for i in
+ range(args.nreps)]
+ results = dask.compute(*futures)
+ else:
+ results = [
+ run_simulation(i, path, "_".join(vary_param_dict.values()), X_params_dict, X_dgp, y_params_dict,
+ y_dgp, ests, metrics, args) for i in range(args.nreps)]
+ assert all(results)
+
+    else:  # only one parameter is being varied over
+ # get parameter that is being varied over and identify whether it's a DGP/method/fi_method argument
+ for val_name, val in vary_param_vals.items():
+ if vary_param_name in X_params_dict.keys() and vary_param_name in y_params_dict.keys():
+ raise ValueError('Cannot vary over parameter in both X and y DGPs.')
+ elif vary_param_name in X_params_dict.keys():
+ vary_type = "dgp"
+ X_params_dict[vary_param_name] = val
+ elif vary_param_name in y_params_dict.keys():
+ vary_type = "dgp"
+ y_params_dict[vary_param_name] = val
+ else:
+ est_kwargs = list(itertools.chain(*[list(est.kwargs.keys()) for est in list(itertools.chain(*ests))]))
+ if vary_param_name in est_kwargs:
+ vary_type = "est"
+ else:
+ raise ValueError('Invalid vary_param_name.')
+
+ if args.parallel:
+ futures = [
+ dask.delayed(run_simulation)(i, path, val_name, X_params_dict, X_dgp, y_params_dict, y_dgp, ests,
+ metrics, args) for i in range(args.nreps)]
+ results = dask.compute(*futures)
+ else:
+ results = [run_simulation(i, path, val_name, X_params_dict, X_dgp, y_params_dict, y_dgp, ests,
+ metrics, args) for i in range(args.nreps)]
+ assert all(results)
+
+ print('completed all experiments successfully!')
+
+ # get model file names
+ model_comparison_files_all = []
+ for est in ests:
+ estimator_name = est[0].name.split(' - ')[0]
+ model_comparison_file = f'{estimator_name}_comparisons.pkl'
+ model_comparison_files_all.append(model_comparison_file)
+
+    # aggregate results
+ results_list = []
+ if isinstance(vary_param_name, list):
+ for vary_param_dict in vary_param_dicts:
+ val_name = "_".join(vary_param_dict.values())
+
+ for i in range(args.nreps):
+ all_files = glob.glob(oj(path, val_name, 'rep' + str(i), '*'))
+ model_files = sorted([f for f in all_files if os.path.basename(f) in model_comparison_files_all])
+
+ if len(model_files) == 0:
+ print('No files found at ', oj(path, val_name, 'rep' + str(i)))
+ continue
+
+ results = pd.concat(
+ [pkl.load(open(f, 'rb'))['df'] for f in model_files],
+ axis=0
+ )
+
+ for param_name, param_val in vary_param_dict.items():
+ val = vary_param_vals[param_name][param_val]
+ if vary_type[param_name] == "dgp":
+ if np.isscalar(val):
+ results.insert(0, param_name, val)
+ else:
+ results.insert(0, param_name, [val for i in range(results.shape[0])])
+ results.insert(1, param_name + "_name", param_val)
+ elif vary_type[param_name] == "est":
+ results.insert(0, param_name + "_name", copy.deepcopy(results[param_name]))
+ results.insert(0, 'rep', i)
+ results_list.append(results)
+ else:
+ for val_name, val in vary_param_vals.items():
+ for i in range(args.nreps):
+ all_files = glob.glob(oj(path, val_name, 'rep' + str(i), '*'))
+ model_files = sorted([f for f in all_files if os.path.basename(f) in model_comparison_files_all])
+
+ if len(model_files) == 0:
+ print('No files found at ', oj(path, val_name, 'rep' + str(i)))
+ continue
+
+ results = pd.concat(
+ [pkl.load(open(f, 'rb'))['df'] for f in model_files],
+ axis=0
+ )
+ if vary_type == "dgp":
+ if np.isscalar(val):
+ results.insert(0, vary_param_name, val)
+ else:
+ results.insert(0, vary_param_name, [val for i in range(results.shape[0])])
+ results.insert(1, vary_param_name + "_name", val_name)
+ results.insert(2, 'rep', i)
+ elif vary_type == "est":
+ results.insert(0, vary_param_name + "_name", copy.deepcopy(results[vary_param_name]))
+ results.insert(1, 'rep', i)
+ results_list.append(results)
+
+ results_merged = pd.concat(results_list, axis=0)
+ pkl.dump(results_merged, open(oj(path, 'results.pkl'), 'wb'))
+ results_df = reformat_results(results_merged)
+ results_df.to_csv(oj(path, 'results.csv'), index=False)
+
+ print('merged and saved all experiment results successfully!')
+
+# %%
diff --git a/feature_importance/04_run_prediction_real_data.py b/feature_importance/04_run_prediction_real_data.py
new file mode 100644
index 0000000..bb41e3c
--- /dev/null
+++ b/feature_importance/04_run_prediction_real_data.py
@@ -0,0 +1,333 @@
+# Example usage: run from the command line
+# cd feature_importance/
+# python 04_run_prediction_real_data.py --nreps 2 --config test_real_data --split_seed 12345 --ignore_cache
+
+import copy
+import os
+from os.path import join as oj
+import glob
+import argparse
+import pickle as pkl
+import time
+import warnings
+from scipy import stats
+import dask
+from dask.distributed import Client
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+import sys
+from collections import defaultdict
+from typing import Callable, List, Tuple
+import itertools
+from functools import partial
+
+sys.path.append(".")
+sys.path.append("..")
+sys.path.append("../..")
+import fi_config
+from util import ModelConfig, apply_splitting_strategy, auroc_score, auprc_score
+
+from sklearn.metrics import accuracy_score, f1_score, recall_score, \
+ precision_score, average_precision_score, r2_score, explained_variance_score, \
+ mean_squared_error, mean_absolute_error, log_loss
+
+warnings.filterwarnings("ignore", message="Bins whose width")
+
+
+def compare_estimators(estimators: List[ModelConfig],
+ X, y,
+ metrics: List[Tuple[str, Callable]],
+ args, rep) -> Tuple[dict, dict]:
+ """Calculates results given estimators, feature importance estimators, and datasets.
+ Called in run_comparison
+ """
+ if type(estimators) != list:
+ raise Exception("First argument needs to be a list of Models")
+ if type(metrics) != list:
+ raise Exception("Argument metrics needs to be a list containing ('name', callable) pairs")
+
+ # initialize results
+ results = defaultdict(lambda: [])
+
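+    # each replicate uses its own train/test split (split seed offset by the replicate index)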
+ if args.splitting_strategy is not None:
+ X_train, X_tune, X_test, y_train, y_tune, y_test = apply_splitting_strategy(
+ X, y, args.splitting_strategy, args.split_seed + rep)
+ else:
+ X_train = X
+ X_tune = X
+ X_test = X
+ y_train = y
+ y_tune = y
+ y_test = y
+
+ # loop over model estimators
+ for model in tqdm(estimators, leave=False):
+ est = model.cls(**model.kwargs)
+
+ start = time.time()
+ est.fit(X_train, y_train)
+ end = time.time()
+
+ metric_results = {'model': model.name}
+ y_pred = est.predict(X_test)
+ if args.mode != 'regression':
+ y_pred_proba = est.predict_proba(X_test)
+ if args.mode == 'binary_classification':
+ y_pred_proba = y_pred_proba[:, 1]
+ else:
+ y_pred_proba = y_pred
+ for met_name, met in metrics:
+ if met is not None:
+ if args.mode == 'regression' \
+ or met_name in ['accuracy', 'f1', 'precision', 'recall']:
+ metric_results[met_name] = met(y_test, y_pred)
+ else:
+ metric_results[met_name] = met(y_test, y_pred_proba)
+ metric_results['predictions'] = copy.deepcopy(pd.DataFrame(y_pred_proba))
+ metric_results['time'] = end - start
+
+ # initialize results with metadata and metric results
+ kwargs: dict = model.kwargs # dict
+ for k in kwargs:
+ results[k].append(kwargs[k])
+ for met_name, met_val in metric_results.items():
+ results[met_name].append(met_val)
+
+ return results
+
+
+def run_comparison(rep: int,
+ path: str,
+ X, y,
+ metrics: List[Tuple[str, Callable]],
+ estimators: List[ModelConfig],
+ args):
+ estimator_name = estimators[0].name.split(' - ')[0]
+ model_comparison_file = oj(path, f'{estimator_name}_comparisons.pkl')
+ if args.parallel_id is not None:
+ model_comparison_file = f'_{args.parallel_id[0]}.'.join(model_comparison_file.split('.'))
+
+ if os.path.isfile(model_comparison_file) and not args.ignore_cache:
+ print(f'{estimator_name} results already computed and cached. use --ignore_cache to recompute')
+ return
+
+ results = compare_estimators(estimators=estimators,
+ X=X, y=y,
+ metrics=metrics,
+ args=args,
+ rep=rep)
+
+ estimators_list = [e.name for e in estimators]
+
+ df = pd.DataFrame.from_dict(results)
+ df['split_seed'] = args.split_seed + rep
+ if args.nosave_cols is not None:
+ nosave_cols = np.unique([x.strip() for x in args.nosave_cols.split(",")])
+ else:
+ nosave_cols = []
+ for col in nosave_cols:
+ if col in df.columns:
+ df = df.drop(columns=[col])
+
+ output_dict = {
+ # metadata
+ 'sim_name': args.config,
+ 'estimators': estimators_list,
+
+ # actual values
+ 'df': df,
+ }
+ pkl.dump(output_dict, open(model_comparison_file, 'wb'))
+
+ return df
+
+
+def reformat_results(results):
+ results = results.reset_index().drop(columns=['index'])
+ predictions = pd.concat(results.pop('predictions').to_dict()). \
+ reset_index(level=[0, 1]).rename(columns={'level_0': 'index', 'level_1': 'sample_id'})
+ results_df = pd.merge(results, predictions, left_index=True, right_on="index")
+ return results_df
+
+
+def get_metrics(mode: str = 'regression'):
+ if mode == 'binary_classification':
+ return [
+ ('rocauc', auroc_score),
+ ('prauc', auprc_score),
+ ('logloss', log_loss),
+ ('accuracy', accuracy_score),
+ ('f1', f1_score),
+ ('recall', recall_score),
+ ('precision', precision_score),
+ ('avg_precision', average_precision_score)
+ ]
+ elif mode == 'multiclass_classification':
+ return [
+ ('rocauc', partial(auroc_score, multi_class="ovr")),
+ ('prauc', partial(auprc_score, multi_class="ovr")),
+ ('logloss', log_loss),
+ ('accuracy', accuracy_score),
+ ('f1', partial(f1_score, average='micro')),
+ ('recall', partial(recall_score, average='micro')),
+ ('precision', partial(precision_score, average='micro'))
+ ]
+ elif mode == 'regression':
+ return [
+ ('r2', r2_score),
+ ('explained_variance', explained_variance_score),
+ ('mean_squared_error', mean_squared_error),
+ ('mean_absolute_error', mean_absolute_error),
+ ]
+
+
+def run_simulation(i, path, Xpath, ypath, ests, metrics, args):
+ X_df = pd.read_csv(Xpath)
+ y_df = pd.read_csv(ypath)
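+    # optionally subsample rows (without replacement) to at most subsample_n samples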
+ if args.subsample_n is not None:
+ if args.subsample_n < X_df.shape[0]:
+ keep_rows = np.random.choice(X_df.shape[0], args.subsample_n, replace=False)
+ X_df = X_df.iloc[keep_rows]
+ y_df = y_df.iloc[keep_rows]
+ if args.response_idx is None:
+ keep_cols = y_df.columns
+ else:
+ keep_cols = [args.response_idx]
+ for col in keep_cols:
+ y = y_df[col].to_numpy().ravel()
+ keep_idx = ~pd.isnull(y)
+ X = X_df[keep_idx].to_numpy()
+ y = y[keep_idx]
+ if y_df.shape[1] > 1:
+ output_path = oj(path, col)
+ else:
+ output_path = path
+ os.makedirs(oj(output_path, "rep" + str(i)), exist_ok=True)
+ for est in ests:
+ for idx in range(len(est)):
+ if "random_state" in est[idx].kwargs.keys():
+ est[idx].kwargs["random_state"] = i
+ results = run_comparison(
+ rep=i,
+ path=oj(output_path, "rep" + str(i)),
+ X=X, y=y,
+ metrics=metrics,
+ estimators=est,
+ args=args
+ )
+
+ return True
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+
+ default_dir = os.getenv("SCRATCH")
+ if default_dir is not None:
+ default_dir = oj(default_dir, "feature_importance", "results")
+ else:
+ default_dir = oj(os.path.dirname(os.path.realpath(__file__)), 'results')
+
+ parser.add_argument('--nreps', type=int, default=2)
+ parser.add_argument('--mode', type=str, default='regression')
+ parser.add_argument('--model', type=str, default=None)
+ parser.add_argument('--config', type=str, default='mdi_plus.prediction_sims.ccle_rnaseq_regression-')
+ parser.add_argument('--response_idx', type=str, default=None)
+ parser.add_argument('--subsample_n', type=int, default=None)
+ parser.add_argument('--nosave_cols', type=str, default=None)
+
+ # for multiple reruns, should support varying split_seed
+ parser.add_argument('--ignore_cache', action='store_true', default=False)
+ parser.add_argument('--splitting_strategy', type=str, default="train-test")
+ parser.add_argument('--verbose', action='store_true', default=True)
+ parser.add_argument('--parallel', action='store_true', default=False)
+ parser.add_argument('--parallel_id', nargs='+', type=int, default=None)
+ parser.add_argument('--n_cores', type=int, default=None)
+ parser.add_argument('--split_seed', type=int, default=0)
+ parser.add_argument('--results_path', type=str, default=default_dir)
+
+ args = parser.parse_args()
+
+ assert args.splitting_strategy in {
+ 'train-test', 'train-tune-test', 'train-test-lowdata', 'train-tune-test-lowdata'}
+ assert args.mode in {'regression', 'binary_classification', 'multiclass_classification'}
+
+ if args.parallel:
+ if args.n_cores is None:
+ print(os.getenv("SLURM_CPUS_ON_NODE"))
+ n_cores = int(os.getenv("SLURM_CPUS_ON_NODE"))
+ else:
+ n_cores = args.n_cores
+ client = Client(n_workers=n_cores)
+
+ ests, _, Xpath, ypath = fi_config.get_fi_configs(args.config, real_data=True)
+ metrics = get_metrics(args.mode)
+
+ if args.model:
+ ests = list(filter(lambda x: args.model.lower() == x[0].name.lower(), ests))
+
+ if len(ests) == 0:
+ raise ValueError('No valid estimators', 'sim', args.config, 'models', args.model)
+ if args.verbose:
+ print('running', args.config,
+ 'ests', ests)
+ print('\tsaving to', args.results_path)
+
+ results_dir = oj(args.results_path, args.config)
+ path = oj(results_dir, "seed" + str(args.split_seed))
+ os.makedirs(path, exist_ok=True)
+
+ eval_out = defaultdict(list)
+
+ if args.parallel:
+ futures = [dask.delayed(run_simulation)(i, path, Xpath, ypath, ests, metrics, args) for i in range(args.nreps)]
+ results = dask.compute(*futures)
+ else:
+ results = [run_simulation(i, path, Xpath, ypath, ests, metrics, args) for i in range(args.nreps)]
+ assert all(results)
+
+ print('completed all experiments successfully!')
+
+ # get model file names
+ model_comparison_files_all = []
+ for est in ests:
+ estimator_name = est[0].name.split(' - ')[0]
+ model_comparison_file = f'{estimator_name}_comparisons.pkl'
+ model_comparison_files_all.append(model_comparison_file)
+
+ # aggregate results
+ y_df = pd.read_csv(ypath)
+ results_list = []
+ for col in y_df.columns:
+ if y_df.shape[1] > 1:
+ output_path = oj(path, col)
+ else:
+ output_path = path
+ for i in range(args.nreps):
+ all_files = glob.glob(oj(output_path, 'rep' + str(i), '*'))
+ model_files = sorted([f for f in all_files if os.path.basename(f) in model_comparison_files_all])
+
+ if len(model_files) == 0:
+ print('No files found at ', oj(output_path, 'rep' + str(i)))
+ continue
+
+ results = pd.concat(
+ [pkl.load(open(f, 'rb'))['df'] for f in model_files],
+ axis=0
+ )
+ results.insert(0, 'rep', i)
+ if y_df.shape[1] > 1:
+ results.insert(1, 'y_task', col)
+ results_list.append(results)
+
+ results_merged = pd.concat(results_list, axis=0)
+ pkl.dump(results_merged, open(oj(path, 'results.pkl'), 'wb'))
+ results_df = reformat_results(results_merged)
+ results_df.to_csv(oj(path, 'results.csv'), index=False)
+
+ print('merged and saved all experiment results successfully!')
+
+# %%
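The `--parallel` branch above fans `run_simulation` out over repetitions with `dask.delayed` and gathers them with `dask.compute`. A minimal, self-contained sketch of that pattern on a toy function (a local four-worker client is assumed here; this is not the actual simulation code):

```python
import dask
from dask.distributed import Client

def one_rep(i):
    # stand-in for a single repetition; any picklable function works here
    return i * i

if __name__ == "__main__":
    client = Client(n_workers=4)        # local cluster with 4 workers (assumed)
    futures = [dask.delayed(one_rep)(i) for i in range(8)]
    results = dask.compute(*futures)    # blocks until every repetition finishes
    print(results)                      # (0, 1, 4, 9, 16, 25, 36, 49)
    client.close()
```

Once a `Client` exists, `dask.compute` routes the delayed tasks through it, which is why the driver only needs to construct the client before calling `dask.compute(*futures)`.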
diff --git a/feature_importance/__init__.py b/feature_importance/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/fi_config/__init__.py b/feature_importance/fi_config/__init__.py
new file mode 100644
index 0000000..c4c5798
--- /dev/null
+++ b/feature_importance/fi_config/__init__.py
@@ -0,0 +1,13 @@
+import importlib
+
+def get_fi_configs(config_name, real_data=False):
+ if real_data:
+ ests = importlib.import_module(f'fi_config.{config_name}.models')
+ dgp = importlib.import_module(f'fi_config.{config_name}.dgp')
+ return ests.ESTIMATORS, ests.FI_ESTIMATORS, dgp.X_PATH, dgp.Y_PATH
+ else:
+ ests = importlib.import_module(f'fi_config.{config_name}.models')
+ dgp = importlib.import_module(f'fi_config.{config_name}.dgp')
+ return ests.ESTIMATORS, ests.FI_ESTIMATORS, \
+ dgp.X_DGP, dgp.X_PARAMS_DICT, dgp.Y_DGP, dgp.Y_PARAMS_DICT, \
+ dgp.VARY_PARAM_NAME, dgp.VARY_PARAM_VALS
\ No newline at end of file
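For orientation: a config name is a dotted package path under `fi_config/`, and `get_fi_configs` simply imports `<config>.models` and `<config>.dgp` from it. A usage sketch, assuming it is run from `feature_importance/` so that `fi_config` is importable and that the named config packages exist on disk:

```python
import fi_config

# simulation config: estimators, FI estimators, the X/y data-generating
# functions with their parameter dicts, and the parameter grid to vary
(ests, fi_ests,
 X_dgp, X_params, y_dgp, y_params,
 vary_name, vary_vals) = fi_config.get_fi_configs(
    "mdi_plus.classification_sims.enhancer_logistic_dgp")

# real-data config (here the driver's default): estimators, FI estimators,
# and the paths to the X and y csv files
ests, fi_ests, X_path, y_path = fi_config.get_fi_configs(
    "mdi_plus.prediction_sims.ccle_rnaseq_regression-", real_data=True)
```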
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_hier_poly_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_hier_poly_3m_2r_logistic_dgp/dgp.py
new file mode 100644
index 0000000..5e9f5e6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_hier_poly_3m_2r_logistic_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = logistic_hier_model
+Y_PARAMS_DICT = {
+ "m":3,
+ "r":2,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0.25": 0.25, "0.15": 0.15, "0.05": 0.05, "0": None},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
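Each `dgp.py`, like the one above, only declares the sampling functions and the grid of parameters to sweep; `VARY_PARAM_NAME` lists the two swept parameters and `VARY_PARAM_VALS` maps display labels to values. The implied cross product can be enumerated as in this standalone sketch (not the driver's actual loop):

```python
import itertools

VARY_PARAM_VALS = {
    "frac_label_corruption": {"0.25": 0.25, "0.15": 0.15, "0.05": 0.05, "0": None},
    "sample_row_n": {"100": 100, "250": 250, "472": 472},
}

# 4 corruption levels x 3 sample sizes = 12 simulation settings
for (corrupt_label, corrupt), (n_label, n) in itertools.product(
        VARY_PARAM_VALS["frac_label_corruption"].items(),
        VARY_PARAM_VALS["sample_row_n"].items()):
    print(f"frac_label_corruption={corrupt}, sample_row_n={n}")
```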
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_hier_poly_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_hier_poly_3m_2r_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_hier_poly_3m_2r_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
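The `models.py` files in these classification configs are identical: a single RF base learner plus the six feature-importance methods being compared. Adding another method only means appending one more entry to `FI_ESTIMATORS` at the bottom of such a file. A hypothetical sketch follows; the name `my_fi_method` and its exact return format are assumptions, not part of this PR, beyond taking the evaluation data and the fitted forest like the methods above:

```python
import pandas as pd
from feature_importance.util import FIModelConfig

def my_fi_method(X, y, fit, **kwargs):
    # hypothetical FI callable: returns one importance score per feature,
    # indexed by a "var" column
    return pd.DataFrame({"var": range(X.shape[1]),
                         "importance": fit.feature_importances_})

FI_ESTIMATORS.append(
    [FIModelConfig('MyMethod', my_fi_method, model_type='tree')]
)
```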
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_linear_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_linear_lss_3m_2r_logistic_dgp/dgp.py
new file mode 100644
index 0000000..0643de9
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_linear_lss_3m_2r_logistic_dgp/dgp.py
@@ -0,0 +1,23 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = logistic_partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "s":1,
+ "m":3,
+ "r":2,
+ "tau":0,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0.25": 0.25, "0.15": 0.15, "0.05": 0.05, "0": None},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_linear_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_linear_lss_3m_2r_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_linear_lss_3m_2r_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_logistic_dgp/dgp.py
new file mode 100644
index 0000000..babf294
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_logistic_dgp/dgp.py
@@ -0,0 +1,20 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = logistic_model
+Y_PARAMS_DICT = {
+ "s": 5,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0.25": 0.25, "0.15": 0.15, "0.05": 0.05, "0": None},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_lss_3m_2r_logistic_dgp/dgp.py
new file mode 100644
index 0000000..5874330
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_lss_3m_2r_logistic_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = logistic_lss_model
+Y_PARAMS_DICT = {
+ "m": 3,
+ "r": 2,
+ "tau": 0,
+ "beta": 2,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0.25": 0.25, "0.15": 0.15, "0.05": 0.05, "0": None},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_lss_3m_2r_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_lss_3m_2r_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_hier_poly_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_hier_poly_3m_2r_logistic_dgp/dgp.py
new file mode 100644
index 0000000..16536f5
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_hier_poly_3m_2r_logistic_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = logistic_hier_model
+Y_PARAMS_DICT = {
+ "m":3,
+ "r":2,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_hier_poly_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_hier_poly_3m_2r_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_hier_poly_3m_2r_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_linear_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_linear_lss_3m_2r_logistic_dgp/dgp.py
new file mode 100644
index 0000000..ab616ed
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_linear_lss_3m_2r_logistic_dgp/dgp.py
@@ -0,0 +1,23 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = logistic_partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "s":1,
+ "m":3,
+ "r":2,
+ "tau":0,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_linear_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_linear_lss_3m_2r_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_linear_lss_3m_2r_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_logistic_dgp/dgp.py
new file mode 100644
index 0000000..9e43dad
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_logistic_dgp/dgp.py
@@ -0,0 +1,20 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = logistic_model
+Y_PARAMS_DICT = {
+ "s": 5,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_lss_3m_2r_logistic_dgp/dgp.py
new file mode 100644
index 0000000..e3646d1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_lss_3m_2r_logistic_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = logistic_lss_model
+Y_PARAMS_DICT = {
+ "m": 3,
+ "r": 2,
+ "tau": 0,
+ "beta": 2,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_lss_3m_2r_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_lss_3m_2r_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/dgp.py
new file mode 100644
index 0000000..ec4e6c5
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_juvenile_cleaned.csv",
+ "sample_row_n": 1500,
+ "sample_col_n": None
+}
+Y_DGP = logistic_hier_model
+Y_PARAMS_DICT = {
+ "m":3,
+ "r":2,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_linear_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_linear_lss_3m_2r_logistic_dgp/dgp.py
new file mode 100644
index 0000000..d88c1fa
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_linear_lss_3m_2r_logistic_dgp/dgp.py
@@ -0,0 +1,23 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_juvenile_cleaned.csv",
+ "sample_row_n": 1500,
+ "sample_col_n": None
+}
+Y_DGP = logistic_partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "s":1,
+ "m":3,
+ "r":2,
+ "tau":0,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_linear_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_linear_lss_3m_2r_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_linear_lss_3m_2r_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_logistic_dgp/dgp.py
new file mode 100644
index 0000000..28b92fd
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_logistic_dgp/dgp.py
@@ -0,0 +1,20 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_juvenile_cleaned.csv",
+ "sample_row_n": 1500,
+ "sample_col_n": None
+}
+Y_DGP = logistic_model
+Y_PARAMS_DICT = {
+ "s": 5,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_lss_3m_2r_logistic_dgp/dgp.py
new file mode 100644
index 0000000..5688e4a
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_lss_3m_2r_logistic_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_juvenile_cleaned.csv",
+ "sample_row_n": 1500,
+ "sample_col_n": None
+}
+Y_DGP = logistic_lss_model
+Y_PARAMS_DICT = {
+ "m": 3,
+ "r": 2,
+ "tau": 0,
+ "beta": 2,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_lss_3m_2r_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_lss_3m_2r_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/dgp.py
new file mode 100644
index 0000000..0325f07
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": 1500,
+ "sample_col_n": 100
+}
+Y_DGP = logistic_hier_model
+Y_PARAMS_DICT = {
+ "m":3,
+ "r":2,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_linear_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_linear_lss_3m_2r_logistic_dgp/dgp.py
new file mode 100644
index 0000000..3caf738
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_linear_lss_3m_2r_logistic_dgp/dgp.py
@@ -0,0 +1,23 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": 1500,
+ "sample_col_n": 100
+}
+Y_DGP = logistic_partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "s":1,
+ "m":3,
+ "r":2,
+ "tau":0,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_linear_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_linear_lss_3m_2r_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_linear_lss_3m_2r_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_logistic_dgp/dgp.py
new file mode 100644
index 0000000..5db817e
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_logistic_dgp/dgp.py
@@ -0,0 +1,20 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": 1500,
+ "sample_col_n": 100
+}
+Y_DGP = logistic_model
+Y_PARAMS_DICT = {
+ "s": 5,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_lss_3m_2r_logistic_dgp/dgp.py
new file mode 100644
index 0000000..5df43a0
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_lss_3m_2r_logistic_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": 1500,
+ "sample_col_n": 100
+}
+Y_DGP = logistic_lss_model
+Y_PARAMS_DICT = {
+ "m": 3,
+ "r": 2,
+ "tau": 0,
+ "beta": 2,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_lss_3m_2r_logistic_dgp/models.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_lss_3m_2r_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/copy_models_config.sh b/feature_importance/fi_config/mdi_plus/copy_models_config.sh
new file mode 100644
index 0000000..d6beab6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/copy_models_config.sh
@@ -0,0 +1,29 @@
+echo "Regression DGPs..."
+for f in $(find . -type d \( -name "*_dgp" -o -name "*regression" \) ! -name "*logistic_dgp" ! -name "*classification" ! -name "*robust_dgp" ! -name "*-"); do
+ echo "$f"
+ cp models_regression.py "$f"/models.py
+done
+
+echo "Classification DGPs..."
+for f in $(find . -type d \( -name "*logistic_dgp" -o -name "*classification" \) ! -name "*-"); do
+ echo "$f"
+ cp models_classification.py "$f"/models.py
+done
+
+echo "Robust DGPs..."
+for f in $(find . -type d -name "*robust_dgp" ! -name "*-"); do
+ echo "$f"
+ cp models_robust.py "$f"/models.py
+done
+
+echo "Bias Sims..."
+f=mdi_bias_sims/entropy_sims/linear_dgp
+echo "$f"
+cp models_bias.py "$f"/models.py
+f=mdi_bias_sims/entropy_sims/logistic_dgp
+echo "$f"
+cp models_bias.py "$f"/models.py
+f=mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp
+echo "$f"
+cp models_bias.py "$f"/models.py
+
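`copy_models_config.sh` just fans the shared `models_*.py` templates out into every matching DGP directory. A rough Python equivalent of the classification branch, as a sketch (assumes it is run from this `mdi_plus/` directory with `models_classification.py` alongside it):

```python
from pathlib import Path
import shutil

template = Path("models_classification.py")
for pattern in ("*logistic_dgp", "*classification"):
    for d in sorted(Path(".").rglob(pattern)):
        if d.is_dir() and not d.name.endswith("-"):
            print(d)
            shutil.copy(template, d / "models.py")
```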
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/copy_models_config.sh b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/copy_models_config.sh
new file mode 100644
index 0000000..5ea6241
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/copy_models_config.sh
@@ -0,0 +1,5 @@
+echo "Modeling Choices Classification DGPs..."
+for f in $(find . -type d -name "*_dgp"); do
+ echo "$f"
+ cp models.py "$f"/models.py
+done
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/dgp.py
new file mode 100644
index 0000000..ec4e6c5
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_juvenile_cleaned.csv",
+ "sample_row_n": 1500,
+ "sample_col_n": None
+}
+Y_DGP = logistic_hier_model
+Y_PARAMS_DICT = {
+ "m":3,
+ "r":2,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/models.py
new file mode 100644
index 0000000..9dfc2e4
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble
+from imodels.importance.rf_plus import _fast_r2_score, _neg_log_loss
+from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM
+from sklearn.metrics import roc_auc_score
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_logistic_ridge_logloss', tree_mdi_plus, model_type='tree', other_params={'return_stability_scores': True})],
+ [FIModelConfig('MDI+_logistic_ridge_auroc', tree_mdi_plus, model_type='tree', other_params={'scoring_fns': roc_auc_score, 'return_stability_scores': True})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_logistic_dgp/dgp.py
new file mode 100644
index 0000000..28b92fd
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_logistic_dgp/dgp.py
@@ -0,0 +1,20 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_juvenile_cleaned.csv",
+ "sample_row_n": 1500,
+ "sample_col_n": None
+}
+Y_DGP = logistic_model
+Y_PARAMS_DICT = {
+ "s": 5,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_logistic_dgp/models.py
new file mode 100644
index 0000000..9dfc2e4
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble
+from imodels.importance.rf_plus import _fast_r2_score, _neg_log_loss
+from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM
+from sklearn.metrics import roc_auc_score
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_logistic_ridge_logloss', tree_mdi_plus, model_type='tree', other_params={'return_stability_scores': True})],
+ [FIModelConfig('MDI+_logistic_ridge_auroc', tree_mdi_plus, model_type='tree', other_params={'scoring_fns': roc_auc_score, 'return_stability_scores': True})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/models.py
new file mode 100644
index 0000000..9dfc2e4
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble
+from imodels.importance.rf_plus import _fast_r2_score, _neg_log_loss
+from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM
+from sklearn.metrics import roc_auc_score
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_logistic_ridge_logloss', tree_mdi_plus, model_type='tree', other_params={'return_stability_scores': True})],
+ [FIModelConfig('MDI+_logistic_ridge_auroc', tree_mdi_plus, model_type='tree', other_params={'scoring_fns': roc_auc_score, 'return_stability_scores': True})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/dgp.py
new file mode 100644
index 0000000..0325f07
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": 1500,
+ "sample_col_n": 100
+}
+Y_DGP = logistic_hier_model
+Y_PARAMS_DICT = {
+ "m":3,
+ "r":2,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/models.py
new file mode 100644
index 0000000..9dfc2e4
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble
+from imodels.importance.rf_plus import _fast_r2_score, _neg_log_loss
+from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM
+from sklearn.metrics import roc_auc_score
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_logistic_ridge_logloss', tree_mdi_plus, model_type='tree', other_params={'return_stability_scores': True})],
+ [FIModelConfig('MDI+_logistic_ridge_auroc', tree_mdi_plus, model_type='tree', other_params={'scoring_fns': roc_auc_score, 'return_stability_scores': True})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_logistic_dgp/dgp.py
new file mode 100644
index 0000000..5db817e
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_logistic_dgp/dgp.py
@@ -0,0 +1,20 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": 1500,
+ "sample_col_n": 100
+}
+Y_DGP = logistic_model
+Y_PARAMS_DICT = {
+ "s": 5,
+ "beta": 1,
+ "frac_label_corruption": None
+}
+
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"]
+VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_logistic_dgp/models.py
new file mode 100644
index 0000000..9dfc2e4
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_logistic_dgp/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble
+from imodels.importance.rf_plus import _fast_r2_score, _neg_log_loss
+from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM
+from sklearn.metrics import roc_auc_score
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_logistic_ridge_logloss', tree_mdi_plus, model_type='tree', other_params={'return_stability_scores': True})],
+ [FIModelConfig('MDI+_logistic_ridge_auroc', tree_mdi_plus, model_type='tree', other_params={'scoring_fns': roc_auc_score, 'return_stability_scores': True})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..c57710e
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..e454519
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,22 @@
+import copy
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig
+from imodels.importance.rf_plus import RandomForestPlusRegressor
+from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM
+
+rf_model = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=27)
+ridge_model = RidgeRegressorPPM()
+lasso_model = LassoRegressorPPM()
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})],
+ [ModelConfig('RF-ridge', RandomForestPlusRegressor, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(ridge_model)})],
+ [ModelConfig('RF-lasso', RandomForestPlusRegressor, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(lasso_model)})]
+]
+
+FI_ESTIMATORS = []
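These prediction-sim configs pit a plain RF against RF+ variants that pair the same forest with ridge or lasso partial prediction models. A minimal fit/predict sketch of the 'RF-ridge' configuration above, assuming `RandomForestPlusRegressor` follows the usual scikit-learn fit/predict interface:

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from imodels.importance.rf_plus import RandomForestPlusRegressor
from imodels.importance.ppms import RidgeRegressorPPM

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))
y = X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.1, size=200)

rf = RandomForestRegressor(n_estimators=100, min_samples_leaf=5,
                           max_features=0.33, random_state=27)
rf_plus = RandomForestPlusRegressor(rf_model=rf,
                                    prediction_model=RidgeRegressorPPM())
rf_plus.fit(X, y)
print(rf_plus.predict(X[:5]))
```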
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_linear_dgp/dgp.py
new file mode 100644
index 0000000..193daff
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_linear_dgp/models.py
new file mode 100644
index 0000000..e454519
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_linear_dgp/models.py
@@ -0,0 +1,22 @@
+import copy
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig
+from imodels.importance.rf_plus import RandomForestPlusRegressor
+from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM
+
+rf_model = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=27)
+ridge_model = RidgeRegressorPPM()
+lasso_model = LassoRegressorPPM()
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})],
+ [ModelConfig('RF-ridge', RandomForestPlusRegressor, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(ridge_model)})],
+ [ModelConfig('RF-lasso', RandomForestPlusRegressor, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(lasso_model)})]
+]
+
+FI_ESTIMATORS = []
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/copy_models_config.sh b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/copy_models_config.sh
new file mode 100644
index 0000000..9c837ca
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/copy_models_config.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+# Copy the shared models.py config into every *_dgp subdirectory.
+echo "Modeling Choices Regression DGPs..."
+for f in $(find . -type d -name "*_dgp"); do
+ echo "$f"
+ cp models.py "$f"/models.py
+done
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..f328986
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": int(100 * 1.5), "250": int(250 * 1.5), "500": int(500 * 1.5), "1000": int(1000 * 1.5), "1500": int(1500 * 1.5)}}
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..e454519
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,22 @@
+import copy
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig
+from imodels.importance.rf_plus import RandomForestPlusRegressor
+from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM
+
+rf_model = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=27)
+ridge_model = RidgeRegressorPPM()
+lasso_model = LassoRegressorPPM()
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})],
+ [ModelConfig('RF-ridge', RandomForestPlusRegressor, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(ridge_model)})],
+ [ModelConfig('RF-lasso', RandomForestPlusRegressor, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(lasso_model)})]
+]
+
+FI_ESTIMATORS = []
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_linear_dgp/dgp.py
new file mode 100644
index 0000000..0bed41f
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": int(100 * 1.5), "250": int(250 * 1.5), "500": int(500 * 1.5), "1000": int(1000 * 1.5), "1500": int(1500 * 1.5)}}
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_linear_dgp/models.py
new file mode 100644
index 0000000..e454519
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_linear_dgp/models.py
@@ -0,0 +1,22 @@
+import copy
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig
+from imodels.importance.rf_plus import RandomForestPlusRegressor
+from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM
+
+rf_model = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=27)
+ridge_model = RidgeRegressorPPM()
+lasso_model = LassoRegressorPPM()
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})],
+ [ModelConfig('RF-ridge', RandomForestPlusRegressor, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(ridge_model)})],
+ [ModelConfig('RF-lasso', RandomForestPlusRegressor, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(lasso_model)})]
+]
+
+FI_ESTIMATORS = []
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/models.py
new file mode 100644
index 0000000..e454519
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/models.py
@@ -0,0 +1,22 @@
+import copy
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig
+from imodels.importance.rf_plus import RandomForestPlusRegressor
+from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM
+
+rf_model = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=27)
+ridge_model = RidgeRegressorPPM()
+lasso_model = LassoRegressorPPM()
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})],
+ [ModelConfig('RF-ridge', RandomForestPlusRegressor, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(ridge_model)})],
+ [ModelConfig('RF-lasso', RandomForestPlusRegressor, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(lasso_model)})]
+]
+
+FI_ESTIMATORS = []
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..c57710e
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..1cb8ff1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,23 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig, neg_mean_absolute_error
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble
+from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM
+from sklearn.metrics import mean_absolute_error, r2_score
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_lasso_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_ridge_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_lasso_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_ensemble', tree_mdi_plus_ensemble, model_type='tree', ascending=False,
+ other_params={"ridge": {"prediction_model": RidgeRegressorPPM(gcv_mode='eigen')},
+ "lasso": {"prediction_model": LassoRegressorPPM()},
+ "scoring_fns": {"r2": r2_score, "mae": neg_mean_absolute_error}})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_linear_dgp/dgp.py
new file mode 100644
index 0000000..193daff
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_linear_dgp/models.py
new file mode 100644
index 0000000..1cb8ff1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_linear_dgp/models.py
@@ -0,0 +1,23 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig, neg_mean_absolute_error
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble
+from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM
+from sklearn.metrics import mean_absolute_error, r2_score
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_lasso_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_ridge_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_lasso_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_ensemble', tree_mdi_plus_ensemble, model_type='tree', ascending=False,
+ other_params={"ridge": {"prediction_model": RidgeRegressorPPM(gcv_mode='eigen')},
+ "lasso": {"prediction_model": LassoRegressorPPM()},
+ "scoring_fns": {"r2": r2_score, "mae": neg_mean_absolute_error}})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/copy_models_config.sh b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/copy_models_config.sh
new file mode 100644
index 0000000..9c837ca
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/copy_models_config.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+# Copy the shared models.py config into every *_dgp subdirectory.
+echo "Modeling Choices Regression DGPs..."
+for f in $(find . -type d -name "*_dgp"); do
+ echo "$f"
+ cp models.py "$f"/models.py
+done
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..5e3b23d
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..1cb8ff1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,23 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig, neg_mean_absolute_error
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble
+from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM
+from sklearn.metrics import mean_absolute_error, r2_score
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_lasso_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_ridge_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_lasso_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_ensemble', tree_mdi_plus_ensemble, model_type='tree', ascending=False,
+ other_params={"ridge": {"prediction_model": RidgeRegressorPPM(gcv_mode='eigen')},
+ "lasso": {"prediction_model": LassoRegressorPPM()},
+ "scoring_fns": {"r2": r2_score, "mae": neg_mean_absolute_error}})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_linear_dgp/dgp.py
new file mode 100644
index 0000000..d09d33a
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_linear_dgp/models.py
new file mode 100644
index 0000000..1cb8ff1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_linear_dgp/models.py
@@ -0,0 +1,23 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig, neg_mean_absolute_error
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble
+from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM
+from sklearn.metrics import mean_absolute_error, r2_score
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_lasso_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_ridge_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_lasso_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_ensemble', tree_mdi_plus_ensemble, model_type='tree', ascending=False,
+ other_params={"ridge": {"prediction_model": RidgeRegressorPPM(gcv_mode='eigen')},
+ "lasso": {"prediction_model": LassoRegressorPPM()},
+ "scoring_fns": {"r2": r2_score, "mae": neg_mean_absolute_error}})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/models.py
new file mode 100644
index 0000000..1cb8ff1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/models.py
@@ -0,0 +1,23 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig, neg_mean_absolute_error
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble
+from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM
+from sklearn.metrics import mean_absolute_error, r2_score
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_lasso_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_ridge_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_lasso_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})],
+ [FIModelConfig('MDI+_ensemble', tree_mdi_plus_ensemble, model_type='tree', ascending=False,
+ other_params={"ridge": {"prediction_model": RidgeRegressorPPM(gcv_mode='eigen')},
+ "lasso": {"prediction_model": LassoRegressorPPM()},
+ "scoring_fns": {"r2": r2_score, "mae": neg_mean_absolute_error}})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/__init__.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/dgp.py
new file mode 100644
index 0000000..9b2fb69
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/dgp.py
@@ -0,0 +1,30 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+n = 250
+d = 100
+n_correlated = 50
+
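+# Block-correlation design: d features in d / n_correlated blocks of n_correlated features each; only the first block is correlated (at level rho), the remaining blocks are uncorrelated.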
+X_DGP = sample_block_cor_X
+X_PARAMS_DICT = {
+ "n": n,
+ "d": d,
+ "rho": [0.8] + [0 for i in range(int(d / n_correlated - 1))],
+ "n_blocks": int(d / n_correlated)
+}
+Y_DGP = partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s": 1,
+ "m": 3,
+ "r": 2,
+ "tau": 0
+}
+
+VARY_PARAM_NAME = ["heritability", "rho"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.4": 0.4},
+ "rho": {"0.5": [0.5] + [0 for i in range(int(d / n_correlated - 1))], "0.6": [0.6] + [0 for i in range(int(d / n_correlated - 1))], "0.7": [0.7] + [0 for i in range(int(d / n_correlated - 1))], "0.8": [0.8] + [0 for i in range(int(d / n_correlated - 1))], "0.9": [0.9] + [0 for i in range(int(d / n_correlated - 1))], "0.99": [0.99] + [0 for i in range(int(d / n_correlated - 1))]}}
diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/models.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/models.py
new file mode 100644
index 0000000..cbfbc5b
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.ppms import RidgeRegressorPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_inbag', tree_mdi_plus, model_type='tree', other_params={"sample_split": "inbag", "prediction_model": RidgeRegressorPPM(loo=False)})],
+ [FIModelConfig('MDI+_inbag_noraw', tree_mdi_plus, model_type='tree', other_params={"include_raw": False, "sample_split": "inbag", "prediction_model": RidgeRegressorPPM(loo=False)})],
+ [FIModelConfig('MDI+_noraw', tree_mdi_plus, model_type='tree', other_params={"include_raw": False})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI_with_splits', tree_mdi, model_type='tree', other_params={"include_num_splits": True})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/__init__.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/dgp.py
new file mode 100644
index 0000000..9cadad5
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/dgp.py
@@ -0,0 +1,20 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+
+X_DGP = entropy_X
+X_PARAMS_DICT = {
+ "n": 100
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s": 1
+}
+
+VARY_PARAM_NAME = ["heritability", "n"]
+VARY_PARAM_VALS = {"heritability": {"0.05": 0.05, "0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.6": 0.6, "0.8": 0.8},
+ "n": {"50": 50, "100": 100, "250": 250, "500": 500, "1000": 1000}}
diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/models.py
new file mode 100644
index 0000000..cbfbc5b
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.ppms import RidgeRegressorPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_inbag', tree_mdi_plus, model_type='tree', other_params={"sample_split": "inbag", "prediction_model": RidgeRegressorPPM(loo=False)})],
+ [FIModelConfig('MDI+_inbag_noraw', tree_mdi_plus, model_type='tree', other_params={"include_raw": False, "sample_split": "inbag", "prediction_model": RidgeRegressorPPM(loo=False)})],
+ [FIModelConfig('MDI+_noraw', tree_mdi_plus, model_type='tree', other_params={"include_raw": False})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI_with_splits', tree_mdi, model_type='tree', other_params={"include_num_splits": True})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/__init__.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/dgp.py
new file mode 100644
index 0000000..380d0ef
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/dgp.py
@@ -0,0 +1,17 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+
+X_DGP = entropy_X
+X_PARAMS_DICT = {
+ "n": 100
+}
+Y_DGP = entropy_y
+Y_PARAMS_DICT = {
+ "c": 3
+}
+
+VARY_PARAM_NAME = ["c", "n"]
+VARY_PARAM_VALS = {"c": {"3": 3},
+ "n": {"50": 50, "100": 100, "250": 250, "500": 500, "1000": 1000}}
diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/models.py
new file mode 100644
index 0000000..f6c0b7d
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/models.py
@@ -0,0 +1,24 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(gcv_mode='eigen'), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_ridge_inbag', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(loo=False, gcv_mode='eigen'), 'scoring_fns': _fast_r2_score, "sample_split": "inbag"})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_logistic_logloss_inbag', tree_mdi_plus, model_type='tree', other_params={"sample_split": "inbag", "prediction_model": LogisticClassifierPPM(loo=False)})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI_with_splits', tree_mdi, model_type='tree', other_params={"include_num_splits": True})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
+
diff --git a/feature_importance/fi_config/mdi_plus/models_bias.py b/feature_importance/fi_config/mdi_plus/models_bias.py
new file mode 100644
index 0000000..cbfbc5b
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/models_bias.py
@@ -0,0 +1,23 @@
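+# Model and feature-importance configuration for the MDI bias simulations: MDI+ and its in-bag / no-raw-feature ablations, compared against MDI (with and without split counts), MDI-oob, MDA, and TreeSHAP.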
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.ppms import RidgeRegressorPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_inbag', tree_mdi_plus, model_type='tree', other_params={"sample_split": "inbag", "prediction_model": RidgeRegressorPPM(loo=False)})],
+ [FIModelConfig('MDI+_inbag_noraw', tree_mdi_plus, model_type='tree', other_params={"include_raw": False, "sample_split": "inbag", "prediction_model": RidgeRegressorPPM(loo=False)})],
+ [FIModelConfig('MDI+_noraw', tree_mdi_plus, model_type='tree', other_params={"include_raw": False})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI_with_splits', tree_mdi, model_type='tree', other_params={"include_num_splits": True})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/models_classification.py b/feature_importance/fi_config/mdi_plus/models_classification.py
new file mode 100644
index 0000000..45c16b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/models_classification.py
@@ -0,0 +1,21 @@
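+# Model and feature-importance configuration for the MDI+ classification simulations: MDI+ with a ridge classifier PPM (R2-type scoring) and the default logistic / log-loss variant, plus MDI, MDI-oob, MDA, and TreeSHAP baselines.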
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/models_regression.py b/feature_importance/fi_config/mdi_plus/models_regression.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/models_regression.py
@@ -0,0 +1,18 @@
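+# Model and feature-importance configuration for the MDI+ regression simulations: default MDI+ (ridge) against MDI, MDI-oob, MDA, and TreeSHAP.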
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/models_robust.py b/feature_importance/fi_config/mdi_plus/models_robust.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/models_robust.py
@@ -0,0 +1,23 @@
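+# Model and feature-importance configuration for the robust-regression simulations: MDI+ scored by R2, MAE, and Huber loss (the latter with a Huber prediction model), against MDI, MDI-oob, MDA, and TreeSHAP.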
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..c57710e
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..6071e78
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,18 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.ppms import RidgeRegressorPPM
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_linear_dgp/dgp.py
new file mode 100644
index 0000000..193daff
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_linear_dgp/models.py
new file mode 100644
index 0000000..6071e78
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_linear_dgp/models.py
@@ -0,0 +1,18 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.ppms import RidgeRegressorPPM
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/copy_models_config.sh b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/copy_models_config.sh
new file mode 100644
index 0000000..9c837ca
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/copy_models_config.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+# Copy the shared models.py config into every *_dgp subdirectory.
+echo "Modeling Choices Regression DGPs..."
+for f in $(find . -type d -name "*_dgp"); do
+ echo "$f"
+ cp models.py "$f"/models.py
+done
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..5e3b23d
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..6071e78
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,18 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.ppms import RidgeRegressorPPM
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_linear_dgp/dgp.py
new file mode 100644
index 0000000..d09d33a
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_linear_dgp/models.py
new file mode 100644
index 0000000..6071e78
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_linear_dgp/models.py
@@ -0,0 +1,18 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.ppms import RidgeRegressorPPM
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/models.py
new file mode 100644
index 0000000..6071e78
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/models.py
@@ -0,0 +1,18 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.ppms import RidgeRegressorPPM
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..c57710e
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..dffdf47
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,18 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.ppms import RidgeRegressorPPM
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_linear_dgp/dgp.py
new file mode 100644
index 0000000..193daff
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_linear_dgp/models.py
new file mode 100644
index 0000000..dffdf47
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_linear_dgp/models.py
@@ -0,0 +1,18 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.ppms import RidgeRegressorPPM
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],  # alpha=1e-6 makes the ridge penalty negligible, i.e., effectively an OLS fit
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/copy_models_config.sh b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/copy_models_config.sh
new file mode 100644
index 0000000..9c837ca
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/copy_models_config.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+echo "Modeling Choices Regression DGPs..."
+for f in $(find . -type d -name "*_dgp"); do
+ echo "$f"
+ cp models.py "$f"/models.py
+done
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..5e3b23d
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..dffdf47
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,18 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.ppms import RidgeRegressorPPM
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],  # alpha=1e-6 makes the ridge penalty negligible, i.e., effectively an OLS fit
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_linear_dgp/dgp.py
new file mode 100644
index 0000000..d09d33a
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_linear_dgp/models.py
new file mode 100644
index 0000000..dffdf47
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_linear_dgp/models.py
@@ -0,0 +1,18 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.ppms import RidgeRegressorPPM
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],  # alpha=1e-6 makes the ridge penalty negligible, i.e., effectively an OLS fit
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/models.py
new file mode 100644
index 0000000..dffdf47
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/models.py
@@ -0,0 +1,18 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.ppms import RidgeRegressorPPM
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})],
+ [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})],  # alpha=1e-6 makes the ridge penalty negligible, i.e., effectively an OLS fit
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..064576e
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_col_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_col_n": {"10": 10, "25": 25, "50": 50, "100": 100, "250": 250, "500": 500, "1000": 1000, "2000": 2000}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_dgp/dgp.py
new file mode 100644
index 0000000..cc56c30
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_col_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_col_n": {"10": 10, "25": 25, "50": 50, "100": 100, "250": 250, "500": 500, "1000": 1000, "2000": 2000}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_lss_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..a54aec8
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_lss_3m_2r_dgp/dgp.py
@@ -0,0 +1,24 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "tau": 0,
+ "m": 3,
+ "r": 2,
+ "s": 1,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_col_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_col_n": {"10": 10, "25": 25, "50": 50, "100": 100, "250": 250, "500": 500, "1000": 1000, "2000": 2000}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_lss_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_lss_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_lss_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..70768a7
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_lss_3m_2r_dgp/dgp.py
@@ -0,0 +1,23 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "tau": 0,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_col_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_col_n": {"10": 10, "25": 25, "50": 50, "100": 100, "250": 250, "500": 500, "1000": 1000, "2000": 2000}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_lss_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_lss_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..abbd688
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_juvenile_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": None
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "m"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "m": {"2": 2, "3": 3, "5": 5, "7": 7, "10": 10}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_dgp/dgp.py
new file mode 100644
index 0000000..99ae392
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_juvenile_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": None
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "s"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "s": {"1": 1, "5": 5, "10": 10, "25": 25, "50": 50}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_lss_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..c23b7c2
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_lss_3m_2r_dgp/dgp.py
@@ -0,0 +1,24 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_juvenile_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": None
+}
+Y_DGP = partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "tau": 0,
+ "m": 3,
+ "r": 2,
+ "s": 1,
+}
+
+VARY_PARAM_NAME = ["heritability", "m"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "m": {"2": 2, "3": 3, "5": 5, "7": 7, "10": 10}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_lss_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_lss_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/__init__.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/__init__.py
new file mode 100644
index 0000000..8d1c8b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/__init__.py
@@ -0,0 +1 @@
+
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..429916e
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/dgp.py
@@ -0,0 +1,23 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_juvenile_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": None
+}
+Y_DGP = lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "tau": 0,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "m"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "m": {"2": 2, "3": 3, "5": 5, "7": 7, "10": 10}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..69e810c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": 100
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "m"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "m": {"2": 2, "3": 3, "5": 5, "7": 7, "10": 10}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_dgp/dgp.py
new file mode 100644
index 0000000..e049de9
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": 100
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "s"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "s": {"1": 1, "5": 5, "10": 10, "25": 25, "50": 50}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_lss_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..ab525ee
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_lss_3m_2r_dgp/dgp.py
@@ -0,0 +1,24 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": 100
+}
+Y_DGP = partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "tau": 0,
+ "m": 3,
+ "r": 2,
+ "s": 1,
+}
+
+VARY_PARAM_NAME = ["heritability", "m"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "m": {"2": 2, "3": 3, "5": 5, "7": 7, "10": 10}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_lss_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_lss_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_lss_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..9611cb6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_lss_3m_2r_dgp/dgp.py
@@ -0,0 +1,23 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": 100
+}
+Y_DGP = lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "tau": 0,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "m"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "m": {"2": 2, "3": 3, "5": 5, "7": 7, "10": 10}}
diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_lss_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_lss_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/__init__.py b/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/dgp.py b/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/dgp.py
new file mode 100644
index 0000000..b5860cd
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/dgp.py
@@ -0,0 +1,2 @@
+X_PATH = "data/X_ccle_rnaseq_cleaned_filtered5000.csv"
+Y_PATH = "data/y_ccle_rnaseq.csv"
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/models.py b/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/models.py
new file mode 100644
index 0000000..dfe1a3e
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/models.py
@@ -0,0 +1,24 @@
+import copy
+import numpy as np
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig
+from imodels.importance.rf_plus import RandomForestPlusRegressor
+from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM
+
+
+rf_model = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=42)
+ridge_model = RidgeRegressorPPM()
+lasso_model = LassoRegressorPPM()
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})],
+ [ModelConfig('RF-ridge', RandomForestPlusRegressor, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(ridge_model)})],
+ [ModelConfig('RF-lasso', RandomForestPlusRegressor, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(lasso_model)})],
+]
+
+FI_ESTIMATORS = []
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/__init__.py b/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/dgp.py b/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/dgp.py
new file mode 100644
index 0000000..89ec3e0
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/dgp.py
@@ -0,0 +1,2 @@
+X_PATH = "data/X_enhancer_all.csv"
+Y_PATH = "data/y_enhancer.csv"
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/models.py b/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/models.py
new file mode 100644
index 0000000..8ba5961
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/models.py
@@ -0,0 +1,26 @@
+import copy
+import numpy as np
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.linear_model import RidgeClassifierCV, LogisticRegressionCV
+from sklearn.utils.extmath import softmax
+from feature_importance.util import ModelConfig
+from imodels.importance.rf_plus import RandomForestPlusClassifier
+from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM
+
+
+rf_model = RandomForestClassifier(n_estimators=100, min_samples_leaf=1, max_features='sqrt', random_state=42)
+ridge_model = RidgeClassifierPPM()
+logistic_model = LogisticClassifierPPM()
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 42})],
+ [ModelConfig('RF-ridge', RandomForestPlusClassifier, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(ridge_model)})],
+ [ModelConfig('RF-logistic', RandomForestPlusClassifier, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(logistic_model)})],
+]
+
+FI_ESTIMATORS = []
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/__init__.py b/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/dgp.py b/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/dgp.py
new file mode 100644
index 0000000..79ab2f0
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/dgp.py
@@ -0,0 +1,2 @@
+X_PATH = "data/X_juvenile_cleaned.csv"
+Y_PATH = "data/y_juvenile.csv"
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/models.py b/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/models.py
new file mode 100644
index 0000000..8ba5961
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/models.py
@@ -0,0 +1,26 @@
+import copy
+import numpy as np
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.linear_model import RidgeClassifierCV, LogisticRegressionCV
+from sklearn.utils.extmath import softmax
+from feature_importance.util import ModelConfig
+from imodels.importance.rf_plus import RandomForestPlusClassifier
+from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM
+
+
+rf_model = RandomForestClassifier(n_estimators=100, min_samples_leaf=1, max_features='sqrt', random_state=42)
+ridge_model = RidgeClassifierPPM()
+logistic_model = LogisticClassifierPPM()
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 42})],
+ [ModelConfig('RF-ridge', RandomForestPlusClassifier, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(ridge_model)})],
+ [ModelConfig('RF-logistic', RandomForestPlusClassifier, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(logistic_model)})],
+]
+
+FI_ESTIMATORS = []
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/__init__.py b/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/dgp.py b/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/dgp.py
new file mode 100644
index 0000000..079b2da
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/dgp.py
@@ -0,0 +1,2 @@
+X_PATH = "data/X_splicing_cleaned.csv"
+Y_PATH = "data/y_splicing.csv"
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/models.py b/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/models.py
new file mode 100644
index 0000000..8ba5961
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/models.py
@@ -0,0 +1,26 @@
+import copy
+import numpy as np
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.linear_model import RidgeClassifierCV, LogisticRegressionCV
+from sklearn.utils.extmath import softmax
+from feature_importance.util import ModelConfig
+from imodels.importance.rf_plus import RandomForestPlusClassifier
+from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM
+
+
+rf_model = RandomForestClassifier(n_estimators=100, min_samples_leaf=1, max_features='sqrt', random_state=42)
+ridge_model = RidgeClassifierPPM()
+logistic_model = LogisticClassifierPPM()
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 42})],
+ [ModelConfig('RF-ridge', RandomForestPlusClassifier, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(ridge_model)})],
+ [ModelConfig('RF-logistic', RandomForestPlusClassifier, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(logistic_model)})],
+]
+
+FI_ESTIMATORS = []
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/__init__.py b/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/dgp.py b/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/dgp.py
new file mode 100644
index 0000000..490a687
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/dgp.py
@@ -0,0 +1,2 @@
+X_PATH = "data/X_tcga_cleaned.csv"
+Y_PATH = "data/Y_tcga.csv"
diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/models.py b/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/models.py
new file mode 100644
index 0000000..8ba5961
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/models.py
@@ -0,0 +1,26 @@
+import copy
+import numpy as np
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.linear_model import RidgeClassifierCV, LogisticRegressionCV
+from sklearn.utils.extmath import softmax
+from feature_importance.util import ModelConfig
+from imodels.importance.rf_plus import RandomForestPlusClassifier
+from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM
+
+
+rf_model = RandomForestClassifier(n_estimators=100, min_samples_leaf=1, max_features='sqrt', random_state=42)
+ridge_model = RidgeClassifierPPM()
+logistic_model = LogisticClassifierPPM()
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 42})],
+ [ModelConfig('RF-ridge', RandomForestPlusClassifier, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(ridge_model)})],
+ [ModelConfig('RF-logistic', RandomForestPlusClassifier, model_type='tree',
+ other_params={'rf_model': copy.deepcopy(rf_model),
+ 'prediction_model': copy.deepcopy(logistic_model)})],
+]
+
+FI_ESTIMATORS = []
diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/__init__.py b/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/dgp.py b/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/dgp.py
new file mode 100644
index 0000000..b5860cd
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/dgp.py
@@ -0,0 +1,2 @@
+X_PATH = "data/X_ccle_rnaseq_cleaned_filtered5000.csv"
+Y_PATH = "data/y_ccle_rnaseq.csv"
diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/models.py b/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/models.py
new file mode 100644
index 0000000..4b1dbf8
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+', tree_mdi_plus, model_type='tree', splitting_strategy="train-test-prediction")],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree', splitting_strategy="train-test-prediction")],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree', splitting_strategy="train-test-prediction")],
+ [FIModelConfig('MDA', tree_mda, model_type='tree', splitting_strategy="train-test-prediction")],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree', splitting_strategy="train-test-prediction")]
+]
diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/__init__.py b/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/dgp.py b/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/dgp.py
new file mode 100644
index 0000000..490a687
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/dgp.py
@@ -0,0 +1,2 @@
+X_PATH = "data/X_tcga_cleaned.csv"
+Y_PATH = "data/Y_tcga.csv"
diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/models.py b/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/models.py
new file mode 100644
index 0000000..061791f
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+from functools import partial
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 42})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', splitting_strategy="train-test-prediction", other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': partial(_fast_r2_score, multiclass=True)})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree', splitting_strategy="train-test-prediction")],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree', splitting_strategy="train-test-prediction")],
+ [FIModelConfig('MDA', tree_mda, model_type='tree', splitting_strategy="train-test-prediction")],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree', splitting_strategy="train-test-prediction")]
+]
diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/__init__.py b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/dgp.py b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/dgp.py
new file mode 100644
index 0000000..b5860cd
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/dgp.py
@@ -0,0 +1,2 @@
+X_PATH = "data/X_ccle_rnaseq_cleaned_filtered5000.csv"
+Y_PATH = "data/y_ccle_rnaseq.csv"
diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/models.py b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/models.py
new file mode 100644
index 0000000..4cde78f
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/__init__.py b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/dgp.py b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/dgp.py
new file mode 100644
index 0000000..490a687
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/dgp.py
@@ -0,0 +1,2 @@
+X_PATH = "data/X_tcga_cleaned.csv"
+Y_PATH = "data/Y_tcga.csv"
diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/models.py b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/models.py
new file mode 100644
index 0000000..755fc01
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/models.py
@@ -0,0 +1,20 @@
+from sklearn.ensemble import RandomForestClassifier
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+from imodels.importance.rf_plus import _fast_r2_score
+from imodels.importance.ppms import RidgeClassifierPPM
+from functools import partial
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestClassifier, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 42})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': partial(_fast_r2_score, multiclass=True)})],
+ [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..c57710e
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_dgp/dgp.py
new file mode 100644
index 0000000..193daff
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_lss_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..166f958
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_lss_3m_2r_dgp/dgp.py
@@ -0,0 +1,24 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "tau": 0,
+ "m": 3,
+ "r": 2,
+ "s": 1,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_lss_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_lss_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_lss_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..a1b17f6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_lss_3m_2r_dgp/dgp.py
@@ -0,0 +1,23 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "tau": 0,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_lss_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_lss_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..5e3b23d
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_dgp/dgp.py
new file mode 100644
index 0000000..d09d33a
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_lss_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..919157d
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_lss_3m_2r_dgp/dgp.py
@@ -0,0 +1,24 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "tau": 0,
+ "m": 3,
+ "r": 2,
+ "s": 1,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_lss_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_lss_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_lss_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..9c30166
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_lss_3m_2r_dgp/dgp.py
@@ -0,0 +1,23 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "tau": 0,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_lss_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_lss_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..34a80c8
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_juvenile_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": None
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_dgp/dgp.py
new file mode 100644
index 0000000..3a140eb
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_juvenile_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": None
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_lss_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..d8744e6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_lss_3m_2r_dgp/dgp.py
@@ -0,0 +1,24 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_juvenile_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": None
+}
+Y_DGP = partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "tau": 0,
+ "m": 3,
+ "r": 2,
+ "s": 1,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_lss_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_lss_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/__init__.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/__init__.py
new file mode 100644
index 0000000..8d1c8b6
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/__init__.py
@@ -0,0 +1 @@
+
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..34d4f3f
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/dgp.py
@@ -0,0 +1,23 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_juvenile_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": None
+}
+Y_DGP = lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "tau": 0,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_hier_poly_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..6f2ef48
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_hier_poly_3m_2r_dgp/dgp.py
@@ -0,0 +1,22 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": 100
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_hier_poly_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_hier_poly_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_dgp/dgp.py
new file mode 100644
index 0000000..339a98d
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_dgp/dgp.py
@@ -0,0 +1,21 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": 100
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s":5,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_lss_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..48ca787
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_lss_3m_2r_dgp/dgp.py
@@ -0,0 +1,24 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": 100
+}
+Y_DGP = partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "tau": 0,
+ "m": 3,
+ "r": 2,
+ "s": 1,
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_lss_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_lss_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_lss_3m_2r_dgp/dgp.py
new file mode 100644
index 0000000..7204b75
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_lss_3m_2r_dgp/dgp.py
@@ -0,0 +1,23 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": 1000,
+ "sample_col_n": 100
+}
+Y_DGP = lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "tau": 0,
+ "m": 3,
+ "r": 2
+}
+
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_lss_3m_2r_dgp/models.py
new file mode 100644
index 0000000..44f116c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_lss_3m_2r_dgp/models.py
@@ -0,0 +1,17 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')],
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_10MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..74a7fb8
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_10MS_robust_dgp/dgp.py
@@ -0,0 +1,25 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 10
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_10MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_10MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_25MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..817e499
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_25MS_robust_dgp/dgp.py
@@ -0,0 +1,25 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 25
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_25MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_25MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_10MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..0a6db88
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_10MS_robust_dgp/dgp.py
@@ -0,0 +1,24 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s": 5,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 10
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_10MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_10MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_25MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..e6b0b6a
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_25MS_robust_dgp/dgp.py
@@ -0,0 +1,24 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s": 5,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 25
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_25MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_25MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_10MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..f7cc517
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_10MS_robust_dgp/dgp.py
@@ -0,0 +1,27 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "tau": 0,
+ "sigma": None,
+ "heritability": 0.4,
+ "s": 1,
+ "m": 3,
+ "r": 2,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 10
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_10MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_10MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_25MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..1a55ba0
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_25MS_robust_dgp/dgp.py
@@ -0,0 +1,27 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "tau": 0,
+ "sigma": None,
+ "heritability": 0.4,
+ "s": 1,
+ "m": 3,
+ "r": 2,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 25
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_25MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_25MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_10MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..c0634ea
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_10MS_robust_dgp/dgp.py
@@ -0,0 +1,26 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2,
+ "tau": 0,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 10
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_10MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_10MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_25MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..b36e69c
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_25MS_robust_dgp/dgp.py
@@ -0,0 +1,26 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_ccle_rnaseq_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": 1000
+}
+Y_DGP = lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2,
+ "tau": 0,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 25
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+ "sample_row_n": {"100": 100, "250": 250, "472": 472}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_25MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_25MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_10MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..d64bf10
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_10MS_robust_dgp/dgp.py
@@ -0,0 +1,25 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 10
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500,"1000":1000,"1500":1500}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_10MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_10MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_25MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..8f87a35
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_25MS_robust_dgp/dgp.py
@@ -0,0 +1,25 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = hierarchical_poly
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 25
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+                   "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_25MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_25MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_10MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..864f016
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_10MS_robust_dgp/dgp.py
@@ -0,0 +1,24 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s": 5,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 10
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+                   "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_10MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_10MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_25MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..74e744b
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_25MS_robust_dgp/dgp.py
@@ -0,0 +1,24 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s": 5,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 25
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+                   "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_25MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_25MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_10MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..a17116e
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_10MS_robust_dgp/dgp.py
@@ -0,0 +1,27 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "tau": 0,
+ "sigma": None,
+ "heritability": 0.4,
+ "s": 1,
+ "m": 3,
+ "r": 2,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 10
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+                   "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_10MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_10MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_25MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..8659075
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_25MS_robust_dgp/dgp.py
@@ -0,0 +1,27 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = partial_linear_lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "tau": 0,
+ "sigma": None,
+ "heritability": 0.4,
+ "s": 1,
+ "m": 3,
+ "r": 2,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 25
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+                   "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_25MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_25MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_10MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..25ab03f
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_10MS_robust_dgp/dgp.py
@@ -0,0 +1,26 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2,
+ "tau": 0,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 10
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+                   "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_10MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_10MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_25MS_robust_dgp/dgp.py
new file mode 100644
index 0000000..b3599de
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_25MS_robust_dgp/dgp.py
@@ -0,0 +1,26 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_enhancer_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = lss_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "m": 3,
+ "r": 2,
+ "tau": 0,
+ "corrupt_how": "leverage_normal",
+ "corrupt_size": 0.1,
+ "corrupt_mean": 25
+}
+
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
+VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
+                   "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}}
diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_25MS_robust_dgp/models.py
new file mode 100644
index 0000000..7d18df1
--- /dev/null
+++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_25MS_robust_dgp/models.py
@@ -0,0 +1,22 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss
+from sklearn.metrics import mean_absolute_error
+
+
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
+ [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
diff --git a/feature_importance/fi_config/test/__init__.py b/feature_importance/fi_config/test/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/fi_config/test/dgp.py b/feature_importance/fi_config/test/dgp.py
new file mode 100644
index 0000000..6f9eb20
--- /dev/null
+++ b/feature_importance/fi_config/test/dgp.py
@@ -0,0 +1,31 @@
+import sys
+sys.path.append("../..")
+from feature_importance.scripts.simulations_util import *
+
+
+X_DGP = sample_real_X
+X_PARAMS_DICT = {
+ "fpath": "data/X_splicing_cleaned.csv",
+ "sample_row_n": None,
+ "sample_col_n": None
+}
+Y_DGP = linear_model
+Y_PARAMS_DICT = {
+ "beta": 1,
+ "sigma": None,
+ "heritability": 0.4,
+ "s": 5
+}
+
+# # vary one parameter
+# VARY_PARAM_NAME = "sample_row_n"
+# VARY_PARAM_VALS = {"100": 100, "250": 250, "500": 500, "1000": 1000}
+
+# vary two parameters in a grid
+VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
+ "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000}}
+
+# # vary over n_estimators in RF model in models.py
+# VARY_PARAM_NAME = "n_estimators"
+# VARY_PARAM_VALS = {"placeholder": 0}
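+
+# A note on the grid format (an assumption about how the runner consumes it): each string
+# key in VARY_PARAM_VALS is a label used when naming result folders, while the numeric
+# value replaces the matching entry in X_PARAMS_DICT / Y_PARAMS_DICT (or a model kwarg).
+# With two names in VARY_PARAM_NAME, the 4 heritabilities x 4 sample sizes above expand
+# into a 16-setting grid.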
diff --git a/feature_importance/fi_config/test/models.py b/feature_importance/fi_config/test/models.py
new file mode 100644
index 0000000..6d3537c
--- /dev/null
+++ b/feature_importance/fi_config/test/models.py
@@ -0,0 +1,19 @@
+from sklearn.ensemble import RandomForestRegressor
+from feature_importance.util import ModelConfig, FIModelConfig
+from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
+
+# N_ESTIMATORS=[50, 100, 500, 1000]
+ESTIMATORS = [
+ [ModelConfig('RF', RandomForestRegressor, model_type='tree',
+ other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33})],
+ # [ModelConfig('RF', RandomForestRegressor, model_type='tree', vary_param="n_estimators", vary_param_val=m,
+ # other_params={'min_samples_leaf': 5, 'max_features': 0.33}) for m in N_ESTIMATORS]
+]
+
+FI_ESTIMATORS = [
+ [FIModelConfig('MDI+', tree_mdi_plus, model_type='tree')],
+ [FIModelConfig('MDI', tree_mdi, model_type='tree')],
+ [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
+ [FIModelConfig('MDA', tree_mda, model_type='tree')],
+ [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
+]
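+
+# Note on the config format: each inner list groups the configs for one method; sweeping a
+# parameter (as in the commented-out N_ESTIMATORS example above) simply produces several
+# ModelConfig objects inside the same inner list.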
diff --git a/feature_importance/notebooks/mdi_plus/01_paper_figures.Rmd b/feature_importance/notebooks/mdi_plus/01_paper_figures.Rmd
new file mode 100644
index 0000000..55a46ef
--- /dev/null
+++ b/feature_importance/notebooks/mdi_plus/01_paper_figures.Rmd
@@ -0,0 +1,3205 @@
+---
+title: "MDI+ Simulation Results Summary"
+author: ""
+date: "`r format(Sys.time(), '%B %d, %Y')`"
+output: vthemes::vmodern
+params:
+ results_dir:
+ label: "Results directory"
+ value: "results/"
+ seed:
+ label: "Seed"
+ value: 12345
+ for_paper:
+ label: "Export plots for paper"
+ value: FALSE
+ use_cached:
+ label: "Use cached .rds files"
+ value: TRUE
+ interactive:
+ label: "Interactive plots"
+ value: FALSE
+---
+
+```{r setup, include=FALSE}
+knitr::opts_chunk$set(echo = FALSE, warning = FALSE, message = FALSE)
+
+library(magrittr)
+library(patchwork)
+chunk_idx <- 1
+
+# set parameters
+results_dir <- params$results_dir
+seed <- params$seed
+tables_dir <- file.path("tables")
+figures_dir <- file.path("figures")
+figures_subdirs <- c("regression_sims", "classification_sims", "robust_sims",
+ "misspecified_regression_sims", "varying_sparsity",
+ "varying_p", "modeling_choices", "glm_metric_choices")
+for (figures_subdir in figures_subdirs) {
+ if (!dir.exists(file.path(figures_dir, figures_subdir))) {
+ dir.create(file.path(figures_dir, figures_subdir), recursive = TRUE)
+ }
+}
+
+# miscellaneous helper variables
+heritabilities <- c(0.1, 0.2, 0.4, 0.8)
+frac_label_corruptions <- c("0.25", "0.15", "0.05", "0")
+corrupt_sizes <- c(0.05, 0.025, 0.01, 0)
+mean_shifts <- c(10, 25)
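+# metric column plotted throughout; relabeled to "AUROC"/"PRAUC" for display in each section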
+metric <- "rocauc"
+
+# plot options
+point_size <- 2
+line_size <- 1
+errbar_width <- 0
+if (params$interactive) {
+ plot_fun <- plotly::ggplotly
+} else {
+ plot_fun <- function(x) x
+}
+
+manual_color_palette_choices <- c(
+ "black", "black", "#9B5DFF", "blue",
+ "orange", "#71beb7", "#218a1e", "#cc3399"
+)
+show_methods_choices <- c(
+ "MDI+_ridge_RF", "MDI+_ridge_loo_r2_RF", "MDI+_logistic_logloss_RF", "MDI+_Huber_loo_huber_loss_RF",
+ "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF"
+)
+method_labels_choices <- c(
+ "MDI+ (ridge)", "MDI+ (ridge)", "MDI+ (logistic)", "MDI+ (Huber)",
+ "MDA", "MDI", "MDI-oob", "TreeSHAP"
+)
+color_df <- tibble::tibble(
+ color = manual_color_palette_choices,
+ name = show_methods_choices,
+ label = method_labels_choices
+)
+
+manual_color_palette_all <- NULL
+show_methods_all <- NULL
+method_labels_all <- ggplot2::waiver()
+
+custom_theme <- vthemes::theme_vmodern(
+ size_preset = "medium", bg_color = "white", grid_color = "white",
+ axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
+ axis.text = ggplot2::element_text(size = 14),
+ axis.title = ggplot2::element_text(size = 20, face = "plain"),
+ legend.text = ggplot2::element_text(size = 14),
+ plot.title = ggplot2::element_blank()
+ # plot.title = ggplot2::element_text(size = 12, face = "plain", hjust = 0.5)
+)
+custom_theme_with_legend <- vthemes::theme_vmodern(
+ size_preset = "medium", bg_color = "white", grid_color = "white",
+ axis.title = ggplot2::element_text(size = 12, face = "plain"),
+ legend.text = ggplot2::element_text(size = 9),
+ legend.text.align = 0,
+ plot.title = ggplot2::element_blank()
+)
+custom_theme_without_legend <- vthemes::theme_vmodern(
+ size_preset = "medium", bg_color = "white", grid_color = "white",
+ axis.title = ggplot2::element_text(size = 12, face = "plain"),
+ legend.title = ggplot2::element_blank(),
+ legend.text = ggplot2::element_text(size = 9),
+ legend.text.align = 0,
+ plot.title = ggplot2::element_blank()
+)
+
+fig_height <- 6
+fig_width <- 10
+
+source("../../scripts/viz.R", chdir = TRUE)
+```
+
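+Each simulation section below follows the same recipe: subset the methods of interest
+from `color_df`, read in the corresponding results file, and either export per-panel
+PDFs for the paper (`params$for_paper = TRUE`) or render faceted plots inline. As an
+illustration (the directory names here are constructed the same way as `fname` in the
+chunks below, not hard-coded), results are read from paths of the form:
+
+```{r eval = FALSE}
+# illustrative only -- mirrors how `fname` is built in the regression chunk below
+file.path(results_dir, "mdi_plus.regression_sims.enhancer_linear_dgp",
+          "varying_heritability_sample_row_n", paste0("seed", seed), "results.csv")
+```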
+
+# Regression Simulations {.tabset .tabset-vmodern}
+
+```{r eval = TRUE, results = "asis"}
+keep_methods <- color_df$name %in% c("MDI+_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
+manual_color_palette <- color_df$color[keep_methods]
+show_methods <- color_df$name[keep_methods]
+method_labels <- color_df$label[keep_methods]
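+# draw MDI+ (listed first) at full opacity and fade the competing methods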
+alpha_values <- c(1, rep(0.4, length(method_labels) - 1))
+legend_position <- c(0.73, 0.35)
+
+vary_param_name <- "heritability_sample_row_n"
+y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
+x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")
+
+remove_x_axis_models <- c("enhancer", "ccle_rnaseq", "juvenile")
+keep_legend_x_models <- c("enhancer")
+keep_legend_y_models <- y_models
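+# for paper figures: drop the x-axis title on all but the bottom row of panels and keep
+# the in-panel legend in only one representative panel to avoid duplication when tiled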
+
+for (y_model in y_models) {
+ cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
+ for (x_model in x_models) {
+ cat(sprintf("\n\n### %s \n\n", x_model))
+ plt_ls <- list()
+ sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
+ sim_title <- dplyr::case_when(
+ x_model == "ccle_rnaseq" ~ "CCLE",
+ x_model == "splicing" ~ "Splicing",
+ x_model == "enhancer" ~ "Enhancer",
+ x_model == "juvenile" ~ "Juvenile"
+ )
+ metric_name <- dplyr::case_when(
+ metric == "rocauc" ~ "AUROC",
+ metric == "prauc" ~ "PRAUC"
+ )
+ fname <- file.path(results_dir, paste0("mdi_plus.regression_sims.", sim_name),
+ paste0("varying_", vary_param_name),
+ paste0("seed", seed), "results")
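+    # use the cached .rds when present (and caching is enabled); otherwise read the raw
+    # csv, reformat, and write the cache -- also refreshed if a requested method is missing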
+ if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
+ results <- readRDS(sprintf("%s.rds", fname))
+ if (length(setdiff(show_methods, unique(results$method))) > 0) {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ } else {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ if (params$for_paper) {
+ for (h in heritabilities) {
+ plt <- results %>%
+ dplyr::filter(heritability == !!h) %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_row_n",
+ facet_str = NULL,
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ legend_position = legend_position,
+ custom_theme = custom_theme,
+ inside_legend = TRUE
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method"
+ )
+ if (h != heritabilities[1]) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ if (x_model %in% remove_x_axis_models) {
+ height <- 2.72
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ } else {
+ height <- 3
+ }
+ if (!((h == heritabilities[length(heritabilities)]) &
+ (x_model %in% keep_legend_x_models) &
+ (y_model %in% keep_legend_y_models))) {
+ plt <- plt + ggplot2::guides(color = "none", alpha = "none")
+ }
+ plt_ls[[as.character(h)]] <- plt
+ }
+ plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
+ ggplot2::ggsave(
+ file.path(figures_dir, "regression_sims",
+ sprintf("regression_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)),
+ plot = plt, units = "in", width = 14, height = height
+ )
+ } else {
+ plt <- results %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_row_n",
+ facet_str = "heritability_name",
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ custom_theme = custom_theme
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method",
+ title = sprintf("%s", sim_title)
+ )
+ vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
+ fig_height = fig_height, fig_width = fig_width * 1.3)
+ chunk_idx <- chunk_idx + 1
+ }
+ }
+}
+
+```
+
+```{r eval = TRUE, results = "asis"}
+y_models <- c("lss_3m_2r", "hier_poly_3m_2r")
+x_models <- c("splicing")
+
+remove_x_axis_models <- c("lss_3m_2r")
+keep_legend_x_models <- x_models
+keep_legend_y_models <- c("lss_3m_2r")
+
+if (params$for_paper) {
+ for (y_model in y_models) {
+ for (x_model in x_models) {
+ plt_ls <- list()
+ sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
+ sim_title <- dplyr::case_when(
+ x_model == "ccle_rnaseq" ~ "CCLE",
+ x_model == "splicing" ~ "Splicing",
+ x_model == "enhancer" ~ "Enhancer",
+ x_model == "juvenile" ~ "Juvenile"
+ )
+ metric_name <- dplyr::case_when(
+ metric == "rocauc" ~ "AUROC",
+ metric == "prauc" ~ "PRAUC"
+ )
+ fname <- file.path(results_dir, paste0("mdi_plus.regression_sims.", sim_name),
+ paste0("varying_", vary_param_name),
+ paste0("seed", seed), "results")
+ if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
+ results <- readRDS(sprintf("%s.rds", fname))
+ if (length(setdiff(show_methods, unique(results$method))) > 0) {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ } else {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ for (h in heritabilities) {
+ plt <- results %>%
+ dplyr::filter(heritability == !!h) %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_row_n",
+ facet_str = NULL,
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ legend_position = legend_position,
+ custom_theme = custom_theme,
+ inside_legend = TRUE
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method"
+ )
+ if (h != heritabilities[1]) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ if (y_model %in% remove_x_axis_models) {
+ height <- 2.72
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ } else {
+ height <- 3
+ }
+ if (!((h == heritabilities[length(heritabilities)]) &
+ (x_model %in% keep_legend_x_models) &
+ (y_model %in% keep_legend_y_models))) {
+ plt <- plt + ggplot2::guides(color = "none", alpha = "none")
+ }
+ plt_ls[[as.character(h)]] <- plt
+ }
+ plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
+ ggplot2::ggsave(
+ file.path(figures_dir, "regression_sims",
+ sprintf("main_regression_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)),
+ plot = plt, units = "in", width = 14, height = height
+ )
+ }
+ }
+}
+
+```
+
+
+# Classification Simulations {.tabset .tabset-vmodern}
+
+```{r eval = TRUE, results = "asis"}
+keep_methods <- color_df$name %in% c("MDI+_ridge_RF", "MDI+_logistic_logloss_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
+manual_color_palette <- color_df$color[keep_methods]
+show_methods <- color_df$name[keep_methods]
+method_labels <- color_df$label[keep_methods]
+alpha_values <- c(1, 1, rep(0.4, length(method_labels) - 2))
+legend_position <- c(0.73, 0.4)
+
+vary_param_name <- "frac_label_corruption_sample_row_n"
+y_models <- c("logistic", "lss_3m_2r_logistic", "hier_poly_3m_2r_logistic", "linear_lss_3m_2r_logistic")
+x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")
+
+remove_x_axis_models <- c("enhancer", "ccle_rnaseq", "juvenile")
+keep_legend_x_models <- c("enhancer")
+keep_legend_y_models <- y_models
+
+for (y_model in y_models) {
+ cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
+ for (x_model in x_models) {
+ cat(sprintf("\n\n### %s \n\n", x_model))
+ plt_ls <- list()
+ sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
+ sim_title <- dplyr::case_when(
+ x_model == "ccle_rnaseq" ~ "CCLE",
+ x_model == "splicing" ~ "Splicing",
+ x_model == "enhancer" ~ "Enhancer",
+ x_model == "juvenile" ~ "Juvenile"
+ )
+ metric_name <- dplyr::case_when(
+ metric == "rocauc" ~ "AUROC",
+ metric == "prauc" ~ "PRAUC"
+ )
+ fname <- file.path(results_dir, paste0("mdi_plus.classification_sims.", sim_name),
+ paste0("varying_", vary_param_name),
+ paste0("seed", seed), "results")
+ if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
+ results <- readRDS(sprintf("%s.rds", fname))
+ if (length(setdiff(show_methods, unique(results$method))) > 0) {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ } else {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ if (params$for_paper) {
+ for (h in frac_label_corruptions) {
+ plt <- results %>%
+ dplyr::filter(frac_label_corruption_name == !!h) %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_row_n",
+ facet_str = NULL,
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ legend_position = legend_position,
+ custom_theme = custom_theme,
+ inside_legend = TRUE
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method"
+ )
+ if (h != frac_label_corruptions[1]) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ if (x_model %in% remove_x_axis_models) {
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ height <- 2.72
+ } else {
+ height <- 3
+ }
+ if (!((h == frac_label_corruptions[length(frac_label_corruptions)]) &
+ (x_model %in% keep_legend_x_models) &
+ (y_model %in% keep_legend_y_models))) {
+ plt <- plt + ggplot2::guides(color = "none", alpha = "none")
+ }
+ plt_ls[[as.character(h)]] <- plt
+ }
+ plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
+ ggplot2::ggsave(
+ file.path(figures_dir, "classification_sims",
+ sprintf("classification_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)),
+ plot = plt, units = "in", width = 14, height = height
+ )
+ } else {
+ plt <- results %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_row_n",
+ facet_str = "frac_label_corruption_name",
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ custom_theme = custom_theme
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method",
+ title = sprintf("%s", sim_title)
+ )
+ vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
+ fig_height = fig_height, fig_width = fig_width * 1.3)
+ chunk_idx <- chunk_idx + 1
+ }
+ }
+}
+```
+
+```{r eval = TRUE, results = "asis"}
+y_models <- c("logistic", "linear_lss_3m_2r_logistic")
+x_models <- c("ccle_rnaseq")
+
+remove_x_axis_models <- c("logistic")
+keep_legend_x_models <- x_models
+keep_legend_y_models <- c("logistic")
+
+if (params$for_paper) {
+ for (y_model in y_models) {
+ for (x_model in x_models) {
+ plt_ls <- list()
+ sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
+ sim_title <- dplyr::case_when(
+ x_model == "ccle_rnaseq" ~ "CCLE",
+ x_model == "splicing" ~ "Splicing",
+ x_model == "enhancer" ~ "Enhancer",
+ x_model == "juvenile" ~ "Juvenile"
+ )
+ metric_name <- dplyr::case_when(
+ metric == "rocauc" ~ "AUROC",
+ metric == "prauc" ~ "PRAUC"
+ )
+ fname <- file.path(results_dir, paste0("mdi_plus.classification_sims.", sim_name),
+ paste0("varying_", vary_param_name),
+ paste0("seed", seed), "results")
+ if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
+ results <- readRDS(sprintf("%s.rds", fname))
+ if (length(setdiff(show_methods, unique(results$method))) > 0) {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ } else {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ for (h in frac_label_corruptions) {
+ plt <- results %>%
+ dplyr::filter(frac_label_corruption_name == !!h) %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_row_n",
+ facet_str = NULL,
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ legend_position = legend_position,
+ custom_theme = custom_theme,
+ inside_legend = TRUE
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method"
+ )
+ if (h != frac_label_corruptions[1]) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ if (y_model %in% remove_x_axis_models) {
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ height <- 2.72
+ } else {
+ height <- 3
+ }
+ if (!((h == frac_label_corruptions[length(frac_label_corruptions)]) &
+ (x_model %in% keep_legend_x_models) &
+ (y_model %in% keep_legend_y_models))) {
+ plt <- plt + ggplot2::guides(color = "none", alpha = "none")
+ }
+ plt_ls[[as.character(h)]] <- plt
+ }
+ plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
+ ggplot2::ggsave(
+ file.path(figures_dir, "classification_sims",
+ sprintf("main_classification_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)),
+ plot = plt, units = "in", width = 14, height = height
+ )
+ }
+ }
+}
+```
+
+
+# Robust Simulations {.tabset .tabset-vmodern}
+
+```{r eval = TRUE, results = "asis"}
+keep_methods <- color_df$name %in% c("MDI+_ridge_loo_r2_RF", "MDI+_Huber_loo_huber_loss_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
+manual_color_palette <- color_df$color[keep_methods]
+show_methods <- color_df$name[keep_methods]
+method_labels <- color_df$label[keep_methods]
+alpha_values <- c(1, 1, rep(0.4, length(method_labels) - 2))
+legend_position <- c(0.73, 0.4)
+
+vary_param_name <- "corrupt_size_sample_row_n"
+y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
+x_models <- c("enhancer", "ccle_rnaseq")
+
+remove_x_axis_mean_shifts <- c(10)
+keep_legend_x_models <- c("enhancer", "ccle_rnaseq")
+keep_legend_y_models <- c("linear", "linear_lss_3m_2r")
+keep_legend_mean_shifts <- c(10)
+
+for (y_model in y_models) {
+ cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
+ for (x_model in x_models) {
+ cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-circle} \n\n", x_model))
+ sim_title <- dplyr::case_when(
+ x_model == "ccle_rnaseq" ~ "CCLE",
+ x_model == "splicing" ~ "Splicing",
+ x_model == "enhancer" ~ "Enhancer",
+ x_model == "juvenile" ~ "Juvenile"
+ )
+ for (mean_shift in mean_shifts) {
+ cat(sprintf("\n\n#### Mean Shift = %s \n\n", mean_shift))
+ plt_ls <- list()
+ sim_name <- sprintf("%s_%s_%sMS_robust_dgp", x_model, y_model, mean_shift)
+ fname <- file.path(results_dir, paste0("mdi_plus.robust_sims.", sim_name),
+ paste0("varying_", vary_param_name),
+ paste0("seed", seed), "results")
+ if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
+ results <- readRDS(sprintf("%s.rds", fname))
+ if (length(setdiff(show_methods, unique(results$method))) > 0) {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ } else {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ metric_name <- dplyr::case_when(
+ metric == "rocauc" ~ "AUROC",
+ metric == "prauc" ~ "PRAUC",
+ TRUE ~ metric
+ )
+ if (params$for_paper) {
+ tmp <- results %>%
+ dplyr::filter(corrupt_size_name %in% corrupt_sizes,
+ method %in% show_methods) %>%
+ dplyr::group_by(sample_row_n, corrupt_size_name, method) %>%
+ dplyr::summarise(mean = mean(.data[[metric]]))
+ min_y <- min(tmp$mean)
+ max_y <- max(tmp$mean)
+ for (h in corrupt_sizes) {
+ plt <- results %>%
+ dplyr::filter(corrupt_size_name == !!h) %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_row_n",
+ facet_str = NULL,
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ legend_position = legend_position,
+ custom_theme = custom_theme,
+ inside_legend = TRUE
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method"
+ )
+ if (h != corrupt_sizes[1]) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+          if (mean_shift %in% remove_x_axis_mean_shifts) {
+ height <- 2.72
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ } else {
+ height <- 3
+ }
+ if (!((h == corrupt_sizes[length(corrupt_sizes)]) &
+ (mean_shift %in% keep_legend_mean_shifts) &
+ (x_model %in% keep_legend_x_models) &
+ (y_model %in% keep_legend_y_models))) {
+ plt <- plt + ggplot2::guides(color = "none", alpha = "none")
+ }
+ plt_ls[[as.character(h)]] <- plt
+ }
+ plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
+ ggplot2::ggsave(
+ file.path(figures_dir, "robust_sims",
+ sprintf("robust_sims_%s_%s_%sMS_%s_errbars.pdf", y_model, x_model, mean_shift, metric)),
+ plot = plt, units = "in", width = 14, height = height
+ )
+ } else {
+ plt <- results %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_row_n",
+ facet_str = "corrupt_size_name",
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ custom_theme = custom_theme
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method",
+ title = sprintf("%s", sim_title)
+ )
+ vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
+ fig_height = fig_height, fig_width = fig_width * 1.3)
+ chunk_idx <- chunk_idx + 1
+ }
+ }
+ }
+}
+
+```
+
+```{r eval = TRUE, results = "asis"}
+y_models <- c("lss_3m_2r")
+x_models <- c("enhancer")
+
+remove_x_axis_mean_shifts <- c(10)
+keep_legend_x_models <- x_models
+keep_legend_y_models <- y_models
+keep_legend_mean_shifts <- c(10)
+
+if (params$for_paper) {
+ for (y_model in y_models) {
+ for (x_model in x_models) {
+ sim_title <- dplyr::case_when(
+ x_model == "ccle_rnaseq" ~ "CCLE",
+ x_model == "splicing" ~ "Splicing",
+ x_model == "enhancer" ~ "Enhancer",
+ x_model == "juvenile" ~ "Juvenile"
+ )
+ for (mean_shift in mean_shifts) {
+ plt_ls <- list()
+ sim_name <- sprintf("%s_%s_%sMS_robust_dgp", x_model, y_model, mean_shift)
+ fname <- file.path(results_dir, paste0("mdi_plus.robust_sims.", sim_name),
+ paste0("varying_", vary_param_name),
+ paste0("seed", seed), "results")
+ if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
+ results <- readRDS(sprintf("%s.rds", fname))
+ if (length(setdiff(show_methods, unique(results$method))) > 0) {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ } else {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ metric_name <- dplyr::case_when(
+ metric == "rocauc" ~ "AUROC",
+ metric == "prauc" ~ "PRAUC",
+ TRUE ~ metric
+ )
+ tmp <- results %>%
+ dplyr::filter(corrupt_size_name %in% corrupt_sizes,
+ method %in% show_methods) %>%
+ dplyr::group_by(sample_row_n, corrupt_size_name, method) %>%
+ dplyr::summarise(mean = mean(.data[[metric]]))
+ min_y <- min(tmp$mean)
+ max_y <- max(tmp$mean)
+ for (h in corrupt_sizes) {
+ plt <- results %>%
+ dplyr::filter(corrupt_size_name == !!h) %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_row_n",
+ facet_str = NULL,
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ legend_position = legend_position,
+ custom_theme = custom_theme,
+ inside_legend = TRUE
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method"
+ )
+ if (h != corrupt_sizes[1]) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+          if (mean_shift %in% remove_x_axis_mean_shifts) {
+ height <- 2.72
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ } else {
+ height <- 3
+ }
+ if (!((h == corrupt_sizes[length(corrupt_sizes)]) &
+ (mean_shift %in% keep_legend_mean_shifts) &
+ (x_model %in% keep_legend_x_models) &
+ (y_model %in% keep_legend_y_models))) {
+ plt <- plt + ggplot2::guides(color = "none", alpha = "none")
+ }
+ plt_ls[[as.character(h)]] <- plt
+ }
+ plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
+ ggplot2::ggsave(
+ file.path(figures_dir, "robust_sims",
+ sprintf("main_robust_sims_%s_%s_%sMS_%s_errbars.pdf", y_model, x_model, mean_shift, metric)),
+ plot = plt, units = "in", width = 14, height = height
+ )
+ }
+ }
+ }
+}
+
+```
+
+
+# Correlation Bias Simulations {.tabset .tabset-vmodern}
+
+## Main Figures {.tabset .tabset-pills .tabset-square}
+
+```{r eval = TRUE, results = "asis"}
+manual_color_palette <- c("black", "#218a1e", "orange", "#71beb7", "#cc3399")
+show_methods <- c("MDI+", "MDI-oob", "MDA", "MDI", "TreeSHAP")
+method_labels <- c("MDI+ (ridge)", "MDI-oob", "MDA", "MDI", "TreeSHAP")
+
+custom_color_palette <- c("#38761d", "#9dc89b", "#991500")
+keep_heritabilities <- c(0.1, 0.4)
+sim_name <- "mdi_plus.mdi_bias_sims.correlation_sims.normal_block_cor_partial_linear_lss_dgp"
+fname <- file.path(
+ results_dir,
+ sim_name,
+ "varying_heritability_rho",
+ sprintf("seed%s", seed),
+ "results.csv"
+)
+results <- data.table::fread(fname) %>%
+ dplyr::filter(heritability_name %in% keep_heritabilities)
+
+n <- 250
+p <- 100
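+# feature groups used below: indices 0-5 are signal (Sig), 6-49 form the correlated
+# non-signal group (C-NSig), and the remaining features are non-signal (NSig)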
+sig_ids <- 0:5
+cnsig_ids <- 6:49
+nreps <- length(unique(results$rep))
+
+#### Examine average rank across varying correlation levels ####
+plt_base <- plot_perturbation_stability(
+ results,
+ facet_rows = "heritability_name",
+ facet_cols = "rho_name",
+ param_name = NULL,
+ sig_ids = sig_ids,
+ cnsig_ids = cnsig_ids,
+ plot_types = "errbar",
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels
+)
+
+for (idx in 1:length(keep_heritabilities)) {
+ h <- keep_heritabilities[idx]
+ cat(sprintf("\n\n### PVE = %s\n\n", h))
+ plt_df <- plt_base$agg[[idx]]$data
+ plt_ls <- list()
+ for (method in method_labels) {
+ plt_ls[[method]] <- plt_df %>%
+ dplyr::filter(fi == method) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = rho_name, y = .mean, color = group) +
+ ggplot2::geom_line() +
+ ggplot2::geom_ribbon(
+ ggplot2::aes(x = rho_name,
+ ymin = .mean - (.sd / sqrt(nreps)),
+ ymax = .mean + (.sd / sqrt(nreps)),
+ fill = group),
+ inherit.aes = FALSE, alpha = 0.2
+ ) +
+ ggplot2::labs(
+ x = expression("Correlation ("*rho*")"),
+ y = "Average Rank",
+ color = "Feature\nGroup",
+ title = method
+ ) +
+ ggplot2::coord_cartesian(
+ ylim = c(min(plt_df$.mean) - 3, max(plt_df$.mean) + 3)
+ ) +
+ ggplot2::scale_color_manual(
+ values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
+ ) +
+ ggplot2::scale_fill_manual(
+ values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
+ ) +
+ ggplot2::guides(fill = "none", linetype = "none") +
+ vthemes::theme_vmodern(
+ size_preset = "medium", bg_color = "white", grid_color = "white",
+ axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
+ axis.text = ggplot2::element_text(size = 14),
+ axis.title = ggplot2::element_text(size = 18, face = "plain"),
+ legend.text = ggplot2::element_text(size = 14),
+ legend.title = ggplot2::element_text(size = 18),
+ plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
+ )
+ if (method != method_labels[1]) {
+ plt_ls[[method]] <- plt_ls[[method]] +
+ ggplot2::theme(axis.title.y = ggplot2::element_blank(),
+ axis.ticks.y = ggplot2::element_blank(),
+ axis.text.y = ggplot2::element_blank())
+ }
+ }
+
+ plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1)
+ if (params$for_paper) {
+ ggplot2::ggsave(plt_wide, filename = file.path(figures_dir, sprintf("correlation_sim_pve%s_wide.pdf", h)),
+ # height = fig_height * .55, width = fig_width * .25 * length(show_methods))
+ height = fig_height * .55, width = fig_width * .3 * length(show_methods))
+ } else {
+ vthemes::subchunkify(
+ plt_wide, i = chunk_idx, fig_height = fig_height * .8, fig_width = fig_width * 1.3
+ )
+ chunk_idx <- chunk_idx + 1
+ }
+}
+
+
+#### Examine number of splits across varying correlation levels ####
+cat("\n\n### Number of RF Splits \n\n")
+results <- data.table::fread(fname) %>%
+ dplyr::filter(heritability_name %in% keep_heritabilities,
+ fi == "MDI_with_splits")
+
+total_splits <- results %>%
+ dplyr::group_by(rho_name, heritability, rep) %>%
+ dplyr::summarise(n_splits = sum(av_splits))
+
+plt_df <- results %>%
+ dplyr::left_join(
+ total_splits, by = c("rho_name", "heritability", "rep")
+ ) %>%
+ dplyr::mutate(
+ av_splits = av_splits / n_splits * 100,
+ group = dplyr::case_when(
+ var %in% sig_ids ~ "Sig",
+ var %in% cnsig_ids ~ "C-NSig",
+ TRUE ~ "NSig"
+ )
+ ) %>%
+ dplyr::mutate(
+ group = factor(group, levels = c("Sig", "C-NSig", "NSig"))
+ ) %>%
+ dplyr::group_by(rho_name, heritability, group) %>%
+ dplyr::summarise(
+ .mean = mean(av_splits),
+ .sd = sd(av_splits)
+ )
+
+min_y <- min(plt_df$.mean)
+max_y <- max(plt_df$.mean)
+
+plt_ls <- list()
+for (idx in 1:length(keep_heritabilities)) {
+ h <- keep_heritabilities[idx]
+ plt <- plt_df %>%
+ dplyr::filter(heritability == !!h) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = rho_name, y = .mean, color = group) +
+ ggplot2::geom_line(size = 1) +
+ # ggplot2::geom_ribbon(
+ # ggplot2::aes(x = rho_name,
+ # ymin = .mean - (.sd / sqrt(nreps)),
+ # ymax = .mean + (.sd / sqrt(nreps)),
+ # fill = group),
+ # inherit.aes = FALSE, alpha = 0.2
+ # ) +
+ ggplot2::labs(
+ x = expression("Correlation ("*rho*")"),
+ y = "Percentage of Splits in RF (%)",
+ color = "Feature\nGroup",
+ title = sprintf("PVE = %s", h)
+ ) +
+ ggplot2::coord_cartesian(ylim = c(min_y, max_y)) +
+ ggplot2::scale_color_manual(
+ values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
+ ) +
+ ggplot2::scale_fill_manual(
+ values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
+ ) +
+ ggplot2::guides(fill = "none", linetype = "none") +
+ vthemes::theme_vmodern(
+ size_preset = "medium", bg_color = "white", grid_color = "white",
+ axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
+ axis.text = ggplot2::element_text(size = 14),
+ axis.title = ggplot2::element_text(size = 18, face = "plain"),
+ legend.text = ggplot2::element_text(size = 14),
+ legend.title = ggplot2::element_text(size = 18),
+ plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
+ )
+ if (idx != 1) {
+ plt <- plt +
+ ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ plt_ls[[idx]] <- plt
+}
+
+plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1)
+if (params$for_paper) {
+ ggplot2::ggsave(plt_wide & ggplot2::theme(plot.title = ggplot2::element_blank()),
+ filename = file.path(figures_dir, "correlation_sim_num_splits.pdf"),
+ width = fig_width * 1, height = fig_height * .65)
+} else {
+ vthemes::subchunkify(
+ plt_wide, i = chunk_idx, fig_height = fig_height, fig_width = fig_width * 1.3
+ )
+ chunk_idx <- chunk_idx + 1
+}
+
+```
+
+## Appendix Figures {.tabset .tabset-pills .tabset-square}
+
+```{r eval = TRUE, results = "asis"}
+manual_color_palette <- c("black", "gray")
+show_methods <- c("MDI+", "MDI+_inbag")
+method_labels <- c("MDI+ (LOO)", "MDI+ (in-bag)")
+
+custom_color_palette <- c("#38761d", "#9dc89b", "#991500")
+keep_heritabilities <- c(0.1, 0.4)
+sim_name <- "mdi_plus.mdi_bias_sims.correlation_sims.normal_block_cor_partial_linear_lss_dgp"
+fname <- file.path(
+ results_dir,
+ sim_name,
+ "varying_heritability_rho",
+ sprintf("seed%s", seed),
+ "results.csv"
+)
+results <- data.table::fread(fname) %>%
+ dplyr::filter(heritability_name %in% keep_heritabilities)
+
+n <- 250
+p <- 100
+sig_ids <- 0:5
+cnsig_ids <- 6:49
+nreps <- length(unique(results$rep))
+
+#### Examine average rank across varying correlation levels ####
+plt_base <- plot_perturbation_stability(
+ results,
+ facet_rows = "heritability_name",
+ facet_cols = "rho_name",
+ param_name = NULL,
+ sig_ids = sig_ids,
+ cnsig_ids = cnsig_ids,
+ plot_types = "errbar",
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels
+)
+
+for (idx in 1:length(keep_heritabilities)) {
+ h <- keep_heritabilities[idx]
+ cat(sprintf("\n\n### PVE = %s\n\n", h))
+ plt_df <- plt_base$agg[[idx]]$data
+ plt_ls <- list()
+ for (method in method_labels) {
+ plt_ls[[method]] <- plt_df %>%
+ dplyr::filter(fi == method) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = rho_name, y = .mean, color = group) +
+ ggplot2::geom_line() +
+ ggplot2::geom_ribbon(
+ ggplot2::aes(x = rho_name,
+ ymin = .mean - (.sd / sqrt(nreps)),
+ ymax = .mean + (.sd / sqrt(nreps)),
+ fill = group),
+ inherit.aes = FALSE, alpha = 0.2
+ ) +
+ ggplot2::labs(
+ x = expression("Correlation ("*rho*")"),
+ y = "Average Rank",
+ color = "Feature\nGroup",
+ title = method
+ ) +
+ ggplot2::coord_cartesian(
+ ylim = c(min(plt_df$.mean) - 3, max(plt_df$.mean) + 3)
+ ) +
+ ggplot2::scale_color_manual(
+ values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
+ ) +
+ ggplot2::scale_fill_manual(
+ values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
+ ) +
+ ggplot2::guides(fill = "none", linetype = "none") +
+ vthemes::theme_vmodern(
+ size_preset = "medium", bg_color = "white", grid_color = "white",
+ axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
+ axis.text = ggplot2::element_text(size = 14),
+ axis.title = ggplot2::element_text(size = 18, face = "plain"),
+ legend.text = ggplot2::element_text(size = 14),
+ legend.title = ggplot2::element_text(size = 18),
+ plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
+ )
+ if (method != method_labels[1]) {
+ plt_ls[[method]] <- plt_ls[[method]] +
+ ggplot2::theme(axis.title.y = ggplot2::element_blank(),
+ axis.ticks.y = ggplot2::element_blank(),
+ axis.text.y = ggplot2::element_blank())
+ }
+ }
+
+ plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1)
+ if (params$for_paper) {
+ ggplot2::ggsave(plt_wide,
+ filename = file.path(figures_dir, sprintf("correlation_sim_pve%s_wide_appendix.pdf", h)),
+ # height = fig_height * .55, width = fig_width * .25 * length(show_methods))
+ height = fig_height * .55, width = fig_width * .37 * length(show_methods))
+ } else {
+ vthemes::subchunkify(
+ plt_wide, i = chunk_idx, fig_height = fig_height, fig_width = fig_width * 1.3
+ )
+ chunk_idx <- chunk_idx + 1
+ }
+}
+```
+
+
+# Entropy Bias Simulations {.tabset .tabset-vmodern}
+
+## Main Figures {.tabset .tabset-pills .tabset-square}
+
+```{r eval = TRUE, results = "asis"}
+#### Entropy Regression Results ####
+manual_color_palette <- rev(c("black", "orange", "#71beb7", "#218a1e", "#cc3399"))
+show_methods <- rev(c("MDI+", "MDA", "MDI", "MDI-oob", "TreeSHAP"))
+method_labels <- rev(c("MDI+ (ridge)", "MDA", "MDI", "MDI-oob", "TreeSHAP"))
+alpha_values <- rev(c(1, rep(0.7, length(method_labels) - 1)))
+y_limits <- c(1, 4.5)
+
+metric_name <- "AUROC"
+sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.linear_dgp"
+fname <- file.path(
+ results_dir,
+ sim_name,
+ "varying_heritability_n",
+ sprintf("seed%s", seed),
+ "results.csv"
+)
+
+#### Entropy Regression Avg Rank ####
+results <- data.table::fread(fname) %>%
+ dplyr::filter(heritability_name == 0.1)
+
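+# identity grouping: keep every feature as its own group so the rank of X1 (var 0) can be tracked directly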
+custom_group_fun <- function(var) var
+
+plt_base <- plot_perturbation_stability(
+ results,
+ facet_rows = "heritability_name",
+ facet_cols = "n_name",
+ group_fun = custom_group_fun,
+ param_name = NULL,
+ plot_types = "errbar",
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels
+)
+
+plt_df <- plt_base$agg[[1]]$data
+nreps <- length(unique(results$rep))
+
+plt_regression <- plt_df %>%
+ dplyr::filter(group == 0) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi) +
+ ggplot2::geom_line(size = line_size) +
+ ggplot2::labs(
+ x = "Sample Size",
+ y = expression("Average Rank of "*X[1]),
+ color = "Method",
+ title = "Regression"
+ ) +
+ ggplot2::guides(fill = "none", alpha = "none") +
+ ggplot2::scale_color_manual(values = manual_color_palette,
+ labels = method_labels,
+ guide = ggplot2::guide_legend(reverse = TRUE)) +
+ ggplot2::scale_alpha_manual(values = alpha_values,
+ labels = method_labels,
+ guide = ggplot2::guide_legend(reverse = TRUE)) +
+ ggplot2::coord_cartesian(ylim = y_limits) +
+ vthemes::theme_vmodern(
+ size_preset = "medium", bg_color = "white", grid_color = "white",
+ axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
+ axis.text = ggplot2::element_text(size = 14),
+ axis.title = ggplot2::element_text(size = 18, face = "plain"),
+ legend.text = ggplot2::element_text(size = 14),
+ legend.title = ggplot2::element_text(size = 18),
+ plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
+ )
+
+#### Entropy Regression # Splits ####
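+# For the MDI_with_splits results, convert per-feature average split counts into
+# the percentage of all RF splits within each replicate before averaging over
+# replicates.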
+results <- data.table::fread(fname) %>%
+ dplyr::filter(heritability_name == 0.1,
+ fi == "MDI_with_splits")
+
+nreps <- length(unique(results$rep))
+
+total_splits <- results %>%
+ dplyr::group_by(n_name, heritability, rep) %>%
+ dplyr::summarise(n_splits = sum(av_splits))
+
+regression_num_splits_df <- results %>%
+ dplyr::left_join(
+ total_splits, by = c("n_name", "heritability", "rep")
+ ) %>%
+ dplyr::mutate(
+ av_splits = av_splits / n_splits * 100
+ ) %>%
+ dplyr::group_by(n_name, heritability, var) %>%
+ dplyr::summarise(
+ .mean = mean(av_splits),
+ .sd = sd(av_splits)
+ )
+
+#### Entropy Classification Results ####
+manual_color_palette <- rev(c("black", "#9B5DFF", "orange", "#71beb7", "#218a1e", "#cc3399"))
+show_methods <- rev(c("MDI+_ridge", "MDI+_logistic_logloss", "MDA", "MDI", "MDI-oob", "TreeSHAP"))
+method_labels <- rev(c("MDI+ (ridge)", "MDI+ (logistic)", "MDA", "MDI", "MDI-oob", "TreeSHAP"))
+alpha_values <- rev(c(1, 1, rep(0.7, length(method_labels) - 2)))
+
+metric_name <- "AUROC"
+sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.logistic_dgp"
+fname <- file.path(
+ results_dir,
+ sim_name,
+ "varying_c_n",
+ sprintf("seed%s", seed),
+ "results.csv"
+)
+
+#### Entropy Classification Avg Rank ####
+results <- data.table::fread(fname)
+
+custom_group_fun <- function(var) var
+
+plt_base <- plot_perturbation_stability(
+ results,
+ facet_rows = "c_name",
+ facet_cols = "n_name",
+ group_fun = custom_group_fun,
+ param_name = NULL,
+ plot_types = "errbar",
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels
+)
+
+plt_df <- plt_base$agg[[1]]$data
+nreps <- length(unique(results$rep))
+
+plt_classification <- plt_df %>%
+ dplyr::filter(group == 0) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi) +
+ ggplot2::geom_line(size = line_size) +
+ ggplot2::labs(
+ x = "Sample Size",
+ y = expression("Average Rank of "*X[1]),
+ color = "Method",
+ title = "Classification"
+ ) +
+ ggplot2::guides(fill = "none", alpha = "none") +
+ ggplot2::scale_color_manual(values = manual_color_palette,
+ labels = method_labels,
+ guide = ggplot2::guide_legend(reverse = TRUE)) +
+ ggplot2::scale_alpha_manual(values = alpha_values,
+ labels = method_labels,
+ guide = ggplot2::guide_legend(reverse = TRUE)) +
+ ggplot2::coord_cartesian(
+ ylim = y_limits
+ ) +
+ vthemes::theme_vmodern(
+ size_preset = "medium", bg_color = "white", grid_color = "white",
+ axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
+ axis.text = ggplot2::element_text(size = 14),
+ axis.title = ggplot2::element_text(size = 18, face = "plain"),
+ legend.text = ggplot2::element_text(size = 14),
+ legend.title = ggplot2::element_text(size = 18),
+ plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
+ )
+
+#### Entropy Classification # Splits ####
+results <- data.table::fread(fname) %>%
+ dplyr::filter(fi == "MDI_with_splits")
+
+total_splits <- results %>%
+ dplyr::group_by(n_name, c, rep) %>%
+ dplyr::summarise(n_splits = sum(av_splits))
+
+classification_num_splits_df <- results %>%
+ dplyr::left_join(
+ total_splits, by = c("n_name", "c", "rep")
+ ) %>%
+ dplyr::mutate(
+ av_splits = av_splits / n_splits * 100
+ ) %>%
+ dplyr::group_by(n_name, c, var) %>%
+ dplyr::summarise(
+ .mean = mean(av_splits),
+ .sd = sd(av_splits)
+ )
+
+#### Show Avg Rank Plot ####
+cat("\n\n### Average Rank \n\n")
+plt <- plt_regression + plt_classification
+if (params$for_paper) {
+ ggplot2::ggsave(plt, filename = file.path(figures_dir, "entropy_sims.pdf"),
+ width = fig_width * 1.2, height = fig_height * .65)
+} else {
+ vthemes::subchunkify(
+ plt, i = chunk_idx, fig_height = fig_height * .8, fig_width = fig_width * 1.3
+ )
+ chunk_idx <- chunk_idx + 1
+}
+
+#### Show # Splits Plot ####
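+# Relabel feature indices by their distributional type (Bernoulli signal vs.
+# normal / categorical non-signal) and show the regression and classification
+# split percentages side by side.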
+cat("\n\n### Number of RF Splits \n\n")
+plt_df <- dplyr::bind_rows(
+ Regression = regression_num_splits_df,
+ Classification = classification_num_splits_df,
+ .id = "type"
+) %>%
+ dplyr::mutate(
+ var = dplyr::case_when(
+ var == 0 ~ "Bernoulli (Signal)",
+ var == 1 ~ "Normal (Non-Signal)",
+ var == 2 ~ "4-Categories (Non-Signal)",
+ var == 3 ~ "10-Categories (Non-Signal)",
+ var == 4 ~ "20-Categories (Non-Signal)"
+ ) %>%
+ factor(levels = c("Normal (Non-Signal)",
+ "20-Categories (Non-Signal)",
+ "10-Categories (Non-Signal)",
+ "4-Categories (Non-Signal)",
+ "Bernoulli (Signal)"))
+ )
+
+sim_types <- unique(plt_df$type)
+plt_ls <- list()
+for (idx in 1:length(sim_types)) {
+ sim_type <- sim_types[idx]
+ plt <- plt_df %>%
+ dplyr::filter(type == !!sim_type) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = n_name, y = .mean, color = as.factor(var)) +
+ ggplot2::geom_line(size = 1) +
+ ggplot2::labs(
+ x = "Sample Size",
+ y = "Percentage of Splits in RF (%)",
+ color = "Feature",
+ title = sim_type
+ ) +
+ ggplot2::guides(fill = "none", linetype = "none") +
+ vthemes::theme_vmodern(
+ size_preset = "medium", bg_color = "white", grid_color = "white",
+ axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
+ axis.text = ggplot2::element_text(size = 14),
+ axis.title = ggplot2::element_text(size = 18, face = "plain"),
+ legend.text = ggplot2::element_text(size = 14),
+ legend.title = ggplot2::element_text(size = 18),
+ plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
+ )
+ if (idx != 1) {
+ plt <- plt +
+ ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ plt_ls[[idx]] <- plt
+}
+
+plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1)
+if (params$for_paper) {
+ ggplot2::ggsave(plt_wide, filename = file.path(figures_dir, "entropy_sims_num_splits.pdf"),
+ # height = fig_height * .55, width = fig_width * .25 * length(show_methods))
+ width = fig_width * 1.1, height = fig_height * .65)
+} else {
+ vthemes::subchunkify(
+ plt_wide, i = chunk_idx, fig_height = fig_height, fig_width = fig_width * 1.3
+ )
+ chunk_idx <- chunk_idx + 1
+}
+
+```
+
+## Appendix Figures {.tabset .tabset-pills .tabset-square}
+
+```{r eval = TRUE, results = "asis"}
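+# Appendix version of the entropy-bias simulations: contrast the LOO and in-bag
+# variants of MDI+ against MDI for both DGPs (dashed lines mark the in-bag
+# variants).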
+#### Entropy Regression Results ####
+manual_color_palette <- rev(c("black", "black", "#71beb7"))
+show_methods <- rev(c("MDI+", "MDI+_inbag", "MDI"))
+method_labels <- rev(c("MDI+ (ridge, LOO)", "MDI+ (ridge, in-bag)", "MDI"))
+alpha_values <- rep(1, length(method_labels))
+linetype_values <- c("solid", "dashed", "solid")
+y_limits <- c(1, 5)
+
+metric_name <- "AUROC"
+sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.linear_dgp"
+fname <- file.path(
+ results_dir,
+ sim_name,
+ "varying_heritability_n",
+ sprintf("seed%s", seed),
+ "results.csv"
+)
+
+#### Entropy Regression Avg Rank ####
+results <- data.table::fread(fname) %>%
+ dplyr::filter(heritability_name == 0.1)
+
+custom_group_fun <- function(var) var
+
+plt_base <- plot_perturbation_stability(
+ results,
+ facet_rows = "heritability_name",
+ facet_cols = "n_name",
+ group_fun = custom_group_fun,
+ param_name = NULL,
+ plot_types = "errbar",
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels
+)
+
+plt_df <- plt_base$agg[[1]]$data
+nreps <- length(unique(results$rep))
+
+plt_regression <- plt_df %>%
+ dplyr::filter(group == 0) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi, linetype = fi) +
+ ggplot2::geom_line(size = line_size) +
+ ggplot2::labs(
+ x = "Sample Size",
+ y = expression("Average Rank of "*X[1]),
+ color = "Method",
+ linetype = "Method",
+ title = "Regression"
+ ) +
+ ggplot2::guides(fill = "none", alpha = "none") +
+ ggplot2::scale_color_manual(values = manual_color_palette,
+ labels = method_labels,
+ guide = ggplot2::guide_legend(reverse = TRUE)) +
+ ggplot2::scale_alpha_manual(values = alpha_values,
+ labels = method_labels,
+ guide = ggplot2::guide_legend(reverse = TRUE)) +
+ ggplot2::scale_linetype_manual(values = linetype_values,
+ labels = method_labels,
+ guide = ggplot2::guide_legend(reverse = TRUE)) +
+ ggplot2::coord_cartesian(ylim = y_limits) +
+ vthemes::theme_vmodern(
+ size_preset = "medium", bg_color = "white", grid_color = "white",
+ axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
+ axis.text = ggplot2::element_text(size = 14),
+ axis.title = ggplot2::element_text(size = 18, face = "plain"),
+ legend.text = ggplot2::element_text(size = 14),
+ legend.title = ggplot2::element_text(size = 18),
+ plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5),
+ legend.key.width = ggplot2::unit(1, "cm")
+ )
+
+#### Entropy Classification Results ####
+manual_color_palette <- rev(c("black", "black", "#9B5DFF", "#9B5DFF", "#71beb7"))
+show_methods <- rev(c("MDI+_ridge", "MDI+_ridge_inbag", "MDI+_logistic_logloss", "MDI+_logistic_logloss_inbag", "MDI"))
+method_labels <- rev(c("MDI+ (ridge, LOO)", "MDI+ (ridge, in-bag)", "MDI+ (logistic, LOO)", "MDI+ (logistic, in-bag)", "MDI"))
+alpha_values <- rep(1, length(method_labels))
+linetype_values <- c("solid", "dashed", "solid", "dashed", "solid")
+
+metric_name <- "AUROC"
+sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.logistic_dgp"
+fname <- file.path(
+ results_dir,
+ sim_name,
+ "varying_c_n",
+ sprintf("seed%s", seed),
+ "results.csv"
+)
+
+#### Entropy Classification Avg Rank ####
+results <- data.table::fread(fname)
+
+custom_group_fun <- function(var) var
+
+plt_base <- plot_perturbation_stability(
+ results,
+ facet_rows = "c_name",
+ facet_cols = "n_name",
+ group_fun = custom_group_fun,
+ param_name = NULL,
+ plot_types = "errbar",
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels
+)
+
+plt_df <- plt_base$agg[[1]]$data
+nreps <- length(unique(results$rep))
+
+plt_classification <- plt_df %>%
+ dplyr::filter(group == 0) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi, linetype = fi) +
+ ggplot2::geom_line(size = line_size) +
+ ggplot2::labs(
+ x = "Sample Size",
+ y = expression("Average Rank of "*X[1]),
+ color = "Method",
+ linetype = "Method",
+ title = "Classification"
+ ) +
+ ggplot2::guides(fill = "none", alpha = "none") +
+ ggplot2::scale_color_manual(values = manual_color_palette,
+ labels = method_labels,
+ guide = ggplot2::guide_legend(reverse = TRUE)) +
+ ggplot2::scale_alpha_manual(values = alpha_values,
+ labels = method_labels,
+ guide = ggplot2::guide_legend(reverse = TRUE)) +
+ ggplot2::scale_linetype_manual(values = linetype_values,
+ labels = method_labels,
+ guide = ggplot2::guide_legend(reverse = TRUE)) +
+ ggplot2::coord_cartesian(
+ ylim = y_limits
+ ) +
+ vthemes::theme_vmodern(
+ size_preset = "medium", bg_color = "white", grid_color = "white",
+ axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
+ axis.text = ggplot2::element_text(size = 14),
+ axis.title = ggplot2::element_text(size = 18, face = "plain"),
+ legend.text = ggplot2::element_text(size = 14),
+ legend.title = ggplot2::element_text(size = 18),
+ plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5),
+ legend.key.width = ggplot2::unit(1, "cm")
+ )
+
+#### Show Avg Rank Plot ####
+cat("\n\n### Average Rank \n\n")
+plt <- plt_regression + plt_classification
+if (params$for_paper) {
+ ggplot2::ggsave(plt, filename = file.path(figures_dir, "entropy_sims_appendix.pdf"),
+ width = fig_width * 1.35, height = fig_height * .65)
+} else {
+ vthemes::subchunkify(
+ plt, i = chunk_idx, fig_height = fig_height * .7, fig_width = fig_width * 1.3
+ )
+ chunk_idx <- chunk_idx + 1
+}
+
+```
+
+
+# CCLE Case Study {.tabset .tabset-pills .tabset-square}
+
+```{r eval = TRUE, results = "asis"}
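+# CCLE case study: read the RNASeq feature matrix and per-drug importance
+# results, assess the stability of each method's top-10 features across
+# training-test splits, map Ensembl gene IDs to gene symbols for the top-gene
+# tables, and evaluate test R-squared when predicting from each method's top-k
+# features.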
+manual_color_palette <- c("black", "orange", "#71beb7", "#218a1e", "#cc3399")
+show_methods <- c("MDI+", "MDA", "MDI", "MDI-oob", "TreeSHAP")
+method_labels <- c("MDI+ (ridge)", "MDA", "MDI", "MDI-oob", "TreeSHAP")
+alpha_values <- c(1, rep(0.4, length(method_labels) - 1))
+
+fpath <- "mdi_plus.real_data_case_study.ccle_rnaseq_regression-"
+fname <- file.path(
+ results_dir,
+ fpath,
+ sprintf("seed%s", seed),
+ "results.csv"
+)
+pred_fname <- file.path(
+ results_dir,
+ fpath,
+ sprintf("seed%s", seed),
+ "pred_results.csv"
+)
+X <- data.table::fread("../../data/X_ccle_rnaseq_cleaned_filtered5000.csv")
+X_columns <- colnames(X)
+results <- data.table::fread(fname)
+pred_results <- data.table::fread(pred_fname) %>%
+ tibble::as_tibble() %>%
+ dplyr::mutate(method = fi) %>%
+ # tidyr::unite(col = "method", fi, model, na.rm = TRUE, remove = FALSE) %>%
+ tidyr::pivot_wider(names_from = metric, values_from = metric_value)
+
+out <- plot_top_stability(
+ results,
+ group_id = "y_task",
+ top_r = 10,
+ show_max_features = 5,
+ varnames = colnames(X),
+ base_method = "MDI+ (ridge)",
+ return_df = TRUE,
+ manual_color_palette = rev(manual_color_palette),
+ show_methods = rev(show_methods),
+ method_labels = rev(method_labels)
+)
+
+ranking_df <- out$rankings
+stability_df <- out$stability
+plt_ls <- out$plot_ls
+
+# get gene names instead of ENSG ids
+library(EnsDb.Hsapiens.v79)
+varnames <- ensembldb::select(
+ EnsDb.Hsapiens.v79,
+ keys = stringr::str_remove(colnames(X), "\\.[^\\.]+$"),
+ keytype = "GENEID", columns = c("GENEID","SYMBOL")
+)
+
+# rank genes by mean rank across splits, per drug and method (top 5 shown in tables below)
+top_genes <- ranking_df %>%
+ dplyr::group_by(y_task, fi, var) %>%
+ dplyr::summarise(mean_rank = mean(rank)) %>%
+ dplyr::ungroup() %>%
+ dplyr::arrange(y_task, mean_rank) %>%
+ dplyr::group_by(y_task, fi) %>%
+ dplyr::mutate(rank = 1:dplyr::n()) %>%
+ dplyr::ungroup()
+
+top_genes_table <- top_genes %>%
+ dplyr::filter(rank <= 50) %>%
+ dplyr::left_join(
+ y = data.frame(var = 0:(ncol(X) - 1),
+ GENEID = stringr::str_remove(colnames(X), "\\.[^\\.]+$")),
+ by = "var"
+ ) %>%
+ dplyr::left_join(
+ y = varnames, by = "GENEID"
+ ) %>%
+ dplyr::mutate(
+ value = sprintf("%s (%s)", SYMBOL, round(mean_rank, 2))
+ ) %>%
+ tidyr::pivot_wider(
+ id_cols = c(y_task, rank),
+ names_from = fi,
+ values_from = value
+ ) %>%
+ dplyr::rename(
+ "Drug" = "y_task",
+ "Rank" = "rank"
+ )
+
+# write.csv(
+# top_genes_table, file.path(tables_dir, "ccle_rnaseq_top_genes_table.csv"),
+# row.names = FALSE, quote = FALSE
+# )
+
+# get # genes with perfect stability
+stable_genes_kable <- stability_df %>%
+ dplyr::group_by(fi, y_task) %>%
+ dplyr::summarise(stability_1 = sum(stability_score == 1)) %>%
+ tidyr::pivot_wider(id_cols = fi, names_from = y_task, values_from = stability_1) %>%
+ dplyr::ungroup() %>%
+ as.data.frame()
+
+pred_plt_ls <- list()
+for (plt_id in names(plt_ls)) {
+ cat(sprintf("\n\n## %s \n\n", plt_id))
+ plt <- plt_ls[[plt_id]]
+ if (plt_id == "Summary") {
+ plt <- plt + ggplot2::labs(y = "Drug")
+ if (params$for_paper) {
+ ggplot2::ggsave(plt, filename = file.path(figures_dir, "ccle_rnaseq_stability_top10.pdf"),
+ width = fig_width * 1.5, height = fig_height * 1.1)
+ } else {
+ vthemes::subchunkify(
+ plt, i = sprintf("%s-%s", fpath, plt_id),
+ fig_height = fig_height, fig_width = fig_width * 1.3
+ )
+ }
+
+ cat("\n\n## Table of Top Genes \n\n")
+ if (params$for_paper) {
+ top_genes_kable <- top_genes_table %>%
+ dplyr::filter(Rank <= 5) %>%
+ dplyr::select(-Drug) %>%
+ dplyr::rename(" " = "Rank") %>%
+ vthemes::pretty_kable(format = "latex", longtable = TRUE) %>%
+ kableExtra::kable_styling(latex_options = "repeat_header") %>%
+ # kableExtra::collapse_rows(columns = 1, valign = "middle")
+ kableExtra::pack_rows(
+ index = rep(5, 24) %>% setNames(unique(top_genes_table$Drug))
+ )
+ } else {
+ top_genes_kable <- top_genes_table %>%
+ dplyr::filter(Rank <= 5) %>%
+ dplyr::select(-Drug) %>%
+ dplyr::rename(" " = "Rank") %>%
+ vthemes::pretty_kable() %>%
+ kableExtra::kable_styling(latex_options = "repeat_header") %>%
+ # kableExtra::collapse_rows(columns = 1, valign = "middle")
+ kableExtra::pack_rows(
+ index = rep(5, 24) %>% setNames(unique(top_genes_table$Drug))
+ )
+ vthemes::subchunkify(
+ top_genes_kable, i = sprintf("%s-table", fpath)
+ )
+ }
+ } else {
+ pred_plt_ls[[plt_id]] <- pred_results %>%
+ dplyr::filter(k <= 10, y_task == !!plt_id) %>%
+ plot_metrics(
+ metric = "r2",
+ x_str = "k",
+ facet_str = NULL,
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ inside_legend = FALSE
+ ) +
+ ggplot2::labs(x = "Number of Top Features (k)", y = "Test R-squared",
+ color = "Method", alpha = "Method", title = plt_id) +
+ vthemes::theme_vmodern()
+
+ if (!params$for_paper) {
+ vthemes::subchunkify(
+ plot_fun(plt), i = sprintf("%s-%s", fpath, plt_id),
+ fig_height = fig_height, fig_width = fig_width * 1.3
+ )
+ vthemes::subchunkify(
+ plot_fun(pred_plt_ls[[plt_id]]), i = sprintf("%s-%s-pred", fpath, plt_id),
+ fig_height = fig_height, fig_width = fig_width * 1.3
+ )
+ }
+ }
+}
+
+if (params$for_paper) {
+ all_plt_ls <- list()
+ for (idx in 1:length(pred_plt_ls)) {
+ drug <- names(pred_plt_ls)[idx]
+ plt <- pred_plt_ls[[drug]]
+ if ((idx - 1) %% 4 != 0) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ if ((idx - 1) %/% 4 != 5) {
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ }
+ all_plt_ls[[drug]] <- plt
+ }
+ plt <- patchwork::wrap_plots(all_plt_ls, guides = "collect", ncol = 4, nrow = 6) &
+ ggplot2::theme(
+ plot.title = ggplot2::element_text(hjust = 0.5),
+ panel.background = ggplot2::element_rect(fill = "white"),
+ panel.grid.major = ggplot2::element_line(color = "white")
+ )
+ ggplot2::ggsave(plt,
+ filename = file.path(figures_dir, "ccle_rnaseq_predictions.pdf"),
+ width = fig_width * 1, height = fig_height * 2)
+
+ pred10_ranks <- pred_results %>%
+ dplyr::filter(k == 10) %>%
+ dplyr::group_by(fi, y_task) %>%
+ dplyr::summarise(r2 = mean(r2)) %>%
+ tidyr::pivot_wider(names_from = y_task, values_from = r2) %>%
+ dplyr::ungroup() %>%
+ dplyr::mutate(dplyr::across(`17-AAG`:`ZD-6474`, ~rank(-.x)))
+ pred10_ranks_table <- apply(pred10_ranks, 1, FUN = function(x) table(x[-1])) %>%
+ tibble::as_tibble() %>%
+ tibble::rownames_to_column("Rank") %>%
+ setNames(c("Rank", pred10_ranks$fi)) %>%
+ dplyr::select(Rank, `MDI+`, TreeSHAP, MDI, `MDI-oob`, MDA) %>%
+ vthemes::pretty_kable(format = "latex")
+ top_genes_kable <- top_genes_table %>%
+ dplyr::filter(Rank <= 5) %>%
+ dplyr::select(-Drug) %>%
+ dplyr::rename(" " = "Rank") %>%
+ vthemes::pretty_kable() %>%
+ kableExtra::kable_styling(latex_options = "repeat_header") %>%
+ # kableExtra::collapse_rows(columns = 1, valign = "middle")
+ kableExtra::pack_rows(
+ index = rep(5, 24) %>% setNames(unique(top_genes_table$Drug))
+ )
+ vthemes::subchunkify(
+ top_genes_kable, i = sprintf("%s-table", fpath)
+ )
+}
+```
+
+```{r eval = TRUE, results = "asis"}
+if (params$for_paper) {
+ manual_color_palette_full <- c("black", "orange", "#71beb7", "#218a1e", "#cc3399")
+ show_methods_full <- c("MDI+", "MDA", "MDI", "MDI-oob", "TreeSHAP")
+ method_labels_full <- c("MDI+ (ridge)", "MDA", "MDI", "MDI-oob", "TreeSHAP")
+ keep_methods_list <- list(
+ all = show_methods_full,
+ no_mdioob = c("MDI+", "MDA", "MDI", "TreeSHAP"),
+ zoom = c("MDI+", "MDI", "TreeSHAP")
+ )
+
+ fpath <- "mdi_plus.real_data_case_study.ccle_rnaseq_regression-"
+ fname <- file.path(
+ results_dir,
+ fpath,
+ sprintf("seed%s", seed),
+ "results.csv"
+ )
+ X <- data.table::fread("../../data/X_ccle_rnaseq_cleaned_filtered5000.csv")
+ results <- data.table::fread(fname)
+
+ for (id in names(keep_methods_list)) {
+ keep_methods <- keep_methods_list[[id]]
+ manual_color_palette <- manual_color_palette_full[show_methods_full %in% keep_methods]
+ show_methods <- show_methods_full[show_methods_full %in% keep_methods]
+ method_labels <- method_labels_full[show_methods_full %in% keep_methods]
+
+ plt_ls <- plot_top_stability(
+ results,
+ group_id = "y_task",
+ top_r = 10,
+ show_max_features = 5,
+ varnames = colnames(X),
+ base_method = "MDI+",
+ return_df = FALSE,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels
+ )
+ all_plt_ls <- list()
+ for (idx in 1:length(unique(results$y_task))) {
+ drug <- sort(unique(results$y_task))[idx]
+ plt <- plt_ls[[drug]]
+ if ((idx - 1) %% 4 != 0) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ if ((idx - 1) %/% 4 != 5) {
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ }
+ all_plt_ls[[drug]] <- plt
+ }
+ plt <- patchwork::wrap_plots(all_plt_ls, guides = "collect", ncol = 4, nrow = 6) &
+ ggplot2::theme(
+ plot.title = ggplot2::element_text(hjust = 0.5),
+ panel.background = ggplot2::element_rect(fill = "white"),
+ panel.grid.major = ggplot2::element_line(color = "grey95", size = ggplot2::rel(0.25))
+ )
+ ggplot2::ggsave(plt,
+ filename = file.path(figures_dir, sprintf("ccle_rnaseq_stability_select_features_%s.pdf", id)),
+ width = fig_width * 1, height = fig_height * 2)
+ }
+}
+```
+
+
+# TCGA BRCA Case Study {.tabset .tabset-pills .tabset-square}
+
+```{r eval = TRUE, results = "asis"}
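+# TCGA BRCA case study: same pipeline as CCLE but for a single classification
+# task, reporting stability of the top-10 features, a table of the top 25 genes
+# per method, and prediction performance (AUROC / accuracy) using the top-k
+# features.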
+manual_color_palette <- c("black", "#9B5DFF", "orange", "#71beb7", "#cc3399")
+show_methods <- c("MDI+_ridge", "MDI+_logistic_logloss", "MDA", "MDI", "TreeSHAP")
+method_labels <- c("MDI+ (ridge)", "MDI+ (logistic)", "MDA", "MDI", "TreeSHAP")
+alpha_values <- c(1, 1, rep(0.4, length(method_labels) - 2))
+
+fpath <- "mdi_plus.real_data_case_study.tcga_brca_classification-"
+fname <- file.path(
+ results_dir,
+ fpath,
+ sprintf("seed%s", seed),
+ "results.csv"
+)
+X <- data.table::fread("../../X_tcga_cleaned.csv")
+results <- data.table::fread(fname)
+out <- plot_top_stability(
+ results,
+ group_id = NULL,
+ top_r = 10,
+ show_max_features = 20,
+ varnames = colnames(X),
+ base_method = "MDI+ (ridge)",
+ # descending_methods = "MDI+ (logistic)",
+ return_df = TRUE,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels
+)
+
+ranking_df <- out$rankings
+stability_df <- out$stability
+plt_ls <- out$plot_ls
+
+cat("\n\n## Non-zero Stability Scores Per Method\n\n")
+vthemes::subchunkify(
+ plt_ls[[1]], i = sprintf("%s-%s", fpath, 1),
+ fig_height = fig_height * round(length(unique(results$fi)) / 2), fig_width = fig_width
+)
+
+cat("\n\n## Stability of Top Features\n\n")
+vthemes::subchunkify(
+ plot_fun(plt_ls[[2]]), i = sprintf("%s-%s", fpath, 2),
+ fig_height = fig_height, fig_width = fig_width
+)
+
+# get top 25 genes
+top_genes <- ranking_df %>%
+ dplyr::group_by(fi, var) %>%
+ dplyr::summarise(mean_rank = mean(rank)) %>%
+ dplyr::ungroup() %>%
+ dplyr::arrange(mean_rank) %>%
+ dplyr::group_by(fi) %>%
+ dplyr::mutate(rank = 1:dplyr::n()) %>%
+ dplyr::ungroup()
+
+top_genes_table <- top_genes %>%
+ dplyr::filter(rank <= 25) %>%
+ dplyr::left_join(
+ y = data.frame(var = 0:(ncol(X) - 1),
+ gene = colnames(X)),
+ by = "var"
+ ) %>%
+ dplyr::mutate(
+ value = sprintf("%s (%s)", gene, round(mean_rank, 2))
+ ) %>%
+ tidyr::pivot_wider(
+ id_cols = rank,
+ names_from = fi,
+ values_from = value
+ ) %>%
+ dplyr::rename(
+ "Rank" = "rank"
+ )
+
+# write.csv(
+# top_genes_table, file.path(tables_dir, "tcga_top_genes_table.csv"),
+# row.names = FALSE, quote = FALSE
+# )
+
+stable_genes_kable <- stability_df %>%
+ dplyr::group_by(fi) %>%
+ dplyr::summarise(stability_1 = sum(stability_score == 1)) %>%
+ tibble::as_tibble()
+
+cat("\n\n## Table of Top Genes \n\n")
+if (params$for_paper) {
+ top_genes_kable <- top_genes_table %>%
+ vthemes::pretty_kable(format = "latex")#, longtable = TRUE) %>%
+ # kableExtra::kable_styling(latex_options = "repeat_header")
+} else {
+ top_genes_kable <- top_genes_table %>%
+ vthemes::pretty_kable() %>%
+ kableExtra::kable_styling(latex_options = "repeat_header")
+ vthemes::subchunkify(
+ top_genes_kable, i = sprintf("%s-table", fpath)
+ )
+}
+
+cat("\n\n## Prediction using Top Features\n\n")
+pred_fname <- file.path(
+ results_dir,
+ fpath,
+ sprintf("seed%s", seed),
+ "pred_results.csv"
+)
+pred_results <- data.table::fread(pred_fname) %>%
+ tibble::as_tibble() %>%
+ dplyr::mutate(method = fi) %>%
+ tidyr::pivot_wider(names_from = metric, values_from = metric_value)
+plt <- pred_results %>%
+ dplyr::filter(k <= 25) %>%
+ plot_metrics(
+ metric = c("rocauc", "accuracy"),
+ x_str = "k",
+ facet_str = NULL,
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ custom_theme = custom_theme,
+ inside_legend = FALSE
+ ) +
+ ggplot2::labs(x = "Number of Top Features (k)", y = "Mean Prediction Performance",
+ color = "Method", alpha = "Method")
+if (params$for_paper) {
+ ggplot2::ggsave(plt, filename = file.path(figures_dir, "tcga_brca_prediction.pdf"),
+ width = fig_width * 1.3, height = fig_height * .85)
+} else {
+ vthemes::subchunkify(
+ plot_fun(plt), i = sprintf("%s-pred", fpath),
+ fig_height = fig_height * .8, fig_width = fig_width * 1.5
+ )
+}
+```
+
+```{r eval = TRUE, results = "asis"}
+if (params$for_paper) {
+ manual_color_palette <- c("black", "#9B5DFF", "orange", "#71beb7", "#218a1e", "#cc3399")
+ show_methods <- c("MDI+_ridge", "MDI+_logistic_logloss", "MDA", "MDI", "MDI-oob", "TreeSHAP")
+ method_labels <- c("MDI+ (ridge)", "MDI+ (logistic)", "MDA", "MDI", "MDI-oob", "TreeSHAP")
+
+ # ccle rnaseq
+ fpath <- "mdi_plus.real_data_case_study.ccle_rnaseq_regression-"
+ fname <- file.path(
+ results_dir,
+ fpath,
+ sprintf("seed%s", seed),
+ "results.csv"
+ )
+ X_ccle <- data.table::fread("../../data/X_ccle_rnaseq_cleaned_filtered5000.csv")
+ results_ccle <- data.table::fread(fname) %>%
+ dplyr::mutate(
+ fi = ifelse(fi == "MDI+", "MDI+_ridge", fi)
+ )
+
+ # tcga
+ fpath <- "mdi_plus.real_data_case_study.tcga_brca_classification-"
+ fname <- file.path(
+ results_dir,
+ fpath,
+ sprintf("seed%s", seed),
+ "results.csv"
+ )
+ X_tcga <- data.table::fread("../../data/X_tcga_cleaned.csv")
+ results_tcga <- data.table::fread(fname)
+
+ # join results
+ keep_cols <- c("rep", "y_task", "model", "fi", "index", "var", "importance")
+ results <- dplyr::bind_rows(
+ results_ccle %>% dplyr::select(tidyselect::all_of(keep_cols)),
+ results_tcga %>% dplyr::mutate(y_task = "TCGA-BRCA") %>% dplyr::select(tidyselect::all_of(keep_cols))
+ ) %>%
+ tibble::as_tibble()
+ plt_ls <- plot_top_stability(
+ results,
+ group_id = "y_task",
+ top_r = 10,
+ show_max_features = 5,
+ base_method = "MDI+ (ridge)",
+ return_df = FALSE,
+ manual_color_palette = rev(manual_color_palette),
+ show_methods = rev(show_methods),
+ method_labels = rev(method_labels)
+ )
+
+ plt <- plt_ls$Summary +
+ ggplot2::labs(
+ y = "Drug",
+ x = "Number of Distinct Features in Top 10 Across 32 Training-Test Splits"
+ ) +
+ ggplot2::scale_x_continuous(
+ n.breaks = 6
+ )
+ ggplot2::ggsave(plt, filename = file.path(figures_dir, "casestudy_stability_top10.pdf"),
+ width = fig_width * 1.5, height = fig_height * 1.2)
+
+ keep_methods <- c("MDI+_ridge", "MDI+_logistic_logloss", "MDA", "MDI", "TreeSHAP")
+ manual_color_palette_small <- manual_color_palette[show_methods %in% keep_methods]
+ show_methods_small <- show_methods[show_methods %in% keep_methods]
+ method_labels_small <- method_labels[show_methods %in% keep_methods]
+
+ plt_tcga <- plot_top_stability(
+ results_tcga,
+ top_r = 10,
+ show_max_features = 5,
+ base_method = "MDI+ (ridge)",
+ return_df = FALSE,
+ manual_color_palette = manual_color_palette_small,
+ show_methods = show_methods_small,
+ method_labels = method_labels_small
+ )$`Stability of Top Features` +
+ ggplot2::labs(title = "TCGA-BRCA") +
+ ggplot2::theme(axis.title.y = ggplot2::element_blank())
+
+ plt_ls <- plot_top_stability(
+ results_ccle,
+ group_id = "y_task",
+ top_r = 10,
+ show_max_features = 5,
+ # varnames = colnames(X),
+ base_method = "MDI+ (ridge)",
+ return_df = FALSE,
+ manual_color_palette = manual_color_palette_small[show_methods_small != "MDI+_logistic_logloss"],
+ show_methods = show_methods_small[show_methods_small != "MDI+_logistic_logloss"],
+ method_labels = method_labels_small[show_methods_small != "MDI+_logistic_logloss"]
+ )
+
+ # keep_drugs <- c("Panobinostat", "Lapatinib", "L-685458")
+ keep_drugs <- c("PD-0325901", "Panobinostat", "L-685458")
+ small_plt_ls <- list()
+ for (drug in keep_drugs) {
+ plt <- plt_ls[[drug]]
+ if (drug != keep_drugs[1]) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ small_plt_ls[[drug]] <- plt +
+ ggplot2::guides(fill = "none")
+ }
+ small_plt_ls[["TCGA-BRCA"]] <- plt_tcga
+ # small_plt_ls[[1]] + small_plt_ls[[2]] + small_plt_ls[[3]] + small_plt_ls[[4]] +
+ # patchwork::plot_layout(nrow = 1, ncol = 4, guides = "collect")
+ plt <- patchwork::wrap_plots(small_plt_ls, nrow = 1, ncol = 4, guides = "collect") &
+ ggplot2::theme(
+ plot.title = ggplot2::element_text(hjust = 0.5),
+ panel.background = ggplot2::element_rect(fill = "white"),
+ panel.grid.major = ggplot2::element_line(color = "grey95", size = ggplot2::rel(0.25))
+ )
+ ggplot2::ggsave(plt, filename = file.path(figures_dir, "casestudy_stability_select_features.pdf"),
+ width = fig_width * 1, height = fig_height * 0.5)
+
+ plt <- small_plt_ls[[1]] + small_plt_ls[[2]] + small_plt_ls[[3]] + patchwork::plot_spacer() + small_plt_ls[[4]] +
+ patchwork::plot_layout(nrow = 1, widths = c(1, 1, 1, 0.1, 1), guides = "collect")
+ # plt
+ # grid::grid.draw(
+ # grid::linesGrob(x = grid::unit(c(0.68, 0.68), "npc"),
+ # y = grid::unit(c(0.06, 0.98), "npc"))
+ # )
+ # # ggplot2::ggsave(
+ # # file.path(figures_dir, "casestudy_stability_select_features.pdf"),
+ # # units = "in", width = 14, height = 4
+ # # )
+}
+
+```
+
+
+# Misspecified Models {.tabset .tabset-vmodern}
+
+```{r eval = TRUE, results = "asis"}
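+# Misspecified-model simulations: for each response function (linear, LSS,
+# polynomial, linear+LSS) with omitted variables and each covariate matrix, load
+# cached results (rebuilding the .rds cache when methods are missing from it)
+# and plot the chosen metric (AUROC or PRAUC) against sample size per
+# heritability level; in paper mode, one PDF is written per (y_model, x_model).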
+keep_methods <- color_df$name %in% c("MDI+_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
+manual_color_palette <- color_df$color[keep_methods]
+show_methods <- color_df$name[keep_methods]
+method_labels <- color_df$label[keep_methods]
+alpha_values <- c(1, rep(0.4, length(method_labels) - 1))
+legend_position <- c(0.73, 0.35)
+
+vary_param_name <- "heritability_sample_row_n"
+y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
+x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")
+
+remove_x_axis_models <- c("enhancer", "ccle_rnaseq", "juvenile")
+keep_legend_x_models <- c("enhancer")
+keep_legend_y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
+
+for (y_model in y_models) {
+ cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
+ for (x_model in x_models) {
+ cat(sprintf("\n\n### %s \n\n", x_model))
+ plt_ls <- list()
+ sim_name <- sprintf("%s_%s_dgp_omitted_vars", x_model, y_model)
+ sim_title <- dplyr::case_when(
+ x_model == "ccle_rnaseq" ~ "CCLE",
+ x_model == "splicing" ~ "Splicing",
+ x_model == "enhancer" ~ "Enhancer",
+ x_model == "juvenile" ~ "Juvenile"
+ )
+ metric_name <- dplyr::case_when(
+ metric == "rocauc" ~ "AUROC",
+ metric == "prauc" ~ "PRAUC"
+ )
+ fname <- file.path(results_dir, paste0("mdi_plus.regression_sims.", sim_name),
+ paste0("varying_", vary_param_name),
+ paste0("seed", seed), "results")
+ if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
+ results <- readRDS(sprintf("%s.rds", fname))
+ if (length(setdiff(show_methods, unique(results$method))) > 0) {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ } else {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ if (params$for_paper) {
+ for (h in heritabilities) {
+ plt <- results %>%
+ dplyr::filter(heritability == !!h) %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_row_n",
+ facet_str = NULL,
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ legend_position = legend_position,
+ custom_theme = custom_theme,
+ inside_legend = TRUE
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method"
+ )
+ if (h != heritabilities[1]) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ if (x_model %in% remove_x_axis_models) {
+ height <- 2.72
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ } else {
+ height <- 3
+ }
+ if (!((h == heritabilities[length(heritabilities)]) &
+ (x_model %in% keep_legend_x_models) &
+ (y_model %in% keep_legend_y_models))) {
+ plt <- plt + ggplot2::guides(color = "none", alpha = "none")
+ }
+ plt_ls[[as.character(h)]] <- plt
+ }
+ plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
+ ggplot2::ggsave(
+ file.path(figures_dir, "misspecified_regression_sims",
+ sprintf("misspecified_regression_sims_%s_%s_%s_errbars.pdf",
+ y_model, x_model, metric)),
+ plot = plt, units = "in", width = 14, height = height
+ )
+ } else {
+ plt <- results %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_row_n",
+ facet_str = "heritability_name",
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ custom_theme = custom_theme
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method",
+ title = sprintf("%s", sim_title)
+ )
+ vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
+ fig_height = fig_height, fig_width = fig_width * 1.3)
+ chunk_idx <- chunk_idx + 1
+ }
+ }
+}
+
+```
+
+
+# Varying Sparsity {.tabset .tabset-vmodern}
+
+```{r eval = TRUE, results = "asis"}
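+# Varying-sparsity simulations: sweep the sparsity parameter (s for the linear
+# DGP, m otherwise) on the juvenile and splicing covariate matrices and plot the
+# chosen metric per heritability level.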
+keep_methods <- color_df$name %in% c("MDI+_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
+manual_color_palette <- color_df$color[keep_methods]
+show_methods <- color_df$name[keep_methods]
+method_labels <- color_df$label[keep_methods]
+alpha_values <- c(1, rep(0.4, length(method_labels) - 1))
+legend_position <- c(0.73, 0.7)
+
+y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
+x_models <- c("juvenile", "splicing")
+
+remove_x_axis_models <- c("splicing")
+keep_legend_x_models <- c("splicing")
+keep_legend_y_models <- c("linear", "linear_lss_3m_2r")
+
+for (y_model in y_models) {
+ cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
+ for (x_model in x_models) {
+ cat(sprintf("\n\n### %s \n\n", x_model))
+ plt_ls <- list()
+ sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
+ sim_title <- dplyr::case_when(
+ x_model == "ccle_rnaseq" ~ "CCLE",
+ x_model == "splicing" ~ "Splicing",
+ x_model == "enhancer" ~ "Enhancer",
+ x_model == "juvenile" ~ "Juvenile"
+ )
+ metric_name <- dplyr::case_when(
+ metric == "rocauc" ~ "AUROC",
+ metric == "prauc" ~ "PRAUC"
+ )
+ x_str <- ifelse(y_model == "linear", "s", "m")
+ vary_param_name <- sprintf("heritability_%s", x_str)
+ fname <- file.path(results_dir, paste0("mdi_plus.other_regression_sims.varying_sparsity.", sim_name),
+ paste0("varying_", vary_param_name),
+ paste0("seed", seed), "results")
+ if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
+ results <- readRDS(sprintf("%s.rds", fname))
+ if (length(setdiff(show_methods, unique(results$method))) > 0) {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ } else {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ if (params$for_paper) {
+ for (h in heritabilities) {
+ plt <- results %>%
+ dplyr::filter(heritability == !!h) %>%
+ plot_metrics(
+ metric = metric,
+ x_str = x_str,
+ facet_str = NULL,
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ legend_position = legend_position,
+ custom_theme = custom_theme,
+ inside_legend = TRUE
+ ) +
+ ggplot2::labs(
+ x = toupper(x_str), y = metric_name,
+ color = "Method", alpha = "Method"
+ )
+ if (h != heritabilities[1]) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ if (x_model %in% remove_x_axis_models) {
+ height <- 2.72
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ } else {
+ height <- 3
+ }
+ if (!((h == heritabilities[length(heritabilities)]) &
+ (x_model %in% keep_legend_x_models) &
+ (y_model %in% keep_legend_y_models))) {
+ plt <- plt + ggplot2::guides(color = "none", alpha = "none")
+ }
+ plt_ls[[as.character(h)]] <- plt
+ }
+ plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
+ ggplot2::ggsave(
+ file.path(figures_dir, "varying_sparsity",
+ sprintf("regression_sims_sparsity_%s_%s_%s_errbars.pdf",
+ y_model, x_model, metric)),
+ plot = plt, units = "in", width = 14, height = 3
+ )
+ } else {
+ plt <- results %>%
+ plot_metrics(
+ metric = metric,
+ x_str = x_str,
+ facet_str = "heritability_name",
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ custom_theme = custom_theme
+ ) +
+ ggplot2::labs(
+ x = toupper(x_str), y = metric_name,
+ color = "Method", alpha = "Method",
+ title = sprintf("%s", sim_title)
+ )
+ vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
+ fig_height = fig_height, fig_width = fig_width * 1.3)
+ chunk_idx <- chunk_idx + 1
+ }
+ }
+}
+
+```
+
+
+# Varying \# Features {.tabset .tabset-vmodern}
+
+```{r eval = TRUE, results = "asis"}
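+# Varying number of features: same setup as above, but sweeping the number of
+# covariates (sample_col_n) in the CCLE RNASeq design matrix.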
+keep_methods <- color_df$name %in% c("MDI+_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
+manual_color_palette <- color_df$color[keep_methods]
+show_methods <- color_df$name[keep_methods]
+method_labels <- color_df$label[keep_methods]
+alpha_values <- c(1, rep(0.4, length(method_labels) - 1))
+legend_position <- c(0.73, 0.4)
+
+vary_param_name <- "heritability_sample_col_n"
+y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
+x_models <- c("ccle_rnaseq")
+
+remove_x_axis_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r")
+keep_legend_x_models <- c("ccle_rnaseq")
+keep_legend_y_models <- c("linear")
+
+for (y_model in y_models) {
+ cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
+ for (x_model in x_models) {
+ cat(sprintf("\n\n### %s \n\n", x_model))
+ plt_ls <- list()
+ sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
+ sim_title <- dplyr::case_when(
+ x_model == "ccle_rnaseq" ~ "CCLE",
+ x_model == "splicing" ~ "Splicing",
+ x_model == "enhancer" ~ "Enhancer",
+ x_model == "juvenile" ~ "Juvenile"
+ )
+ metric_name <- dplyr::case_when(
+ metric == "rocauc" ~ "AUROC",
+ metric == "prauc" ~ "PRAUC"
+ )
+ fname <- file.path(results_dir, paste0("mdi_plus.other_regression_sims.varying_p.", sim_name),
+ paste0("varying_", vary_param_name),
+ paste0("seed", seed), "results")
+ if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
+ results <- readRDS(sprintf("%s.rds", fname))
+ if (length(setdiff(show_methods, unique(results$method))) > 0) {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ } else {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ if (params$for_paper) {
+ for (h in heritabilities) {
+ plt <- results %>%
+ dplyr::filter(heritability == !!h) %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_col_n",
+ facet_str = NULL,
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ legend_position = legend_position,
+ custom_theme = custom_theme,
+ inside_legend = TRUE
+ ) +
+ ggplot2::labs(
+ x = "Number of Features", y = metric_name,
+ color = "Method", alpha = "Method"
+ )
+ if (h != heritabilities[1]) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ if (y_model %in% remove_x_axis_models) {
+ height <- 2.72
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ } else {
+ height <- 3
+ }
+ if (!((h == heritabilities[length(heritabilities)]) &
+ (x_model %in% keep_legend_x_models) &
+ (y_model %in% keep_legend_y_models))) {
+ plt <- plt + ggplot2::guides(color = "none", alpha = "none")
+ }
+ plt_ls[[as.character(h)]] <- plt
+ }
+ plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
+ ggplot2::ggsave(
+ file.path(figures_dir, "varying_p",
+ sprintf("regression_sims_vary_p_%s_%s_%s_errbars.pdf",
+ y_model, x_model, metric)),
+ plot = plt, units = "in", width = 14, height = height
+ )
+ } else {
+ plt <- results %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_col_n",
+ facet_str = "heritability_name",
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ custom_theme = custom_theme
+ ) +
+ ggplot2::labs(
+ x = "Number of Features", y = metric_name,
+ color = "Method", alpha = "Method",
+ title = sprintf("%s", sim_title)
+ )
+ vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
+ fig_height = fig_height, fig_width = fig_width * 1.3)
+ chunk_idx <- chunk_idx + 1
+ }
+ }
+}
+
+```
+
+
+# Prediction Results {.tabset .tabset-vmodern}
+
+```{r eval = TRUE, results = "asis"}
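+# Prediction results: aggregate held-out prediction metrics per dataset, compute
+# the percent change of RF+ (ridge for the CCLE regression tasks, logistic for
+# the classification tasks) relative to plain RF, and plot that change against
+# the RF baseline.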
+fpaths <- c(
+ "mdi_plus.prediction_sims.ccle_rnaseq_regression-",
+ "mdi_plus.prediction_sims.enhancer_classification-",
+ "mdi_plus.prediction_sims.splicing_classification-",
+ "mdi_plus.prediction_sims.juvenile_classification-",
+ "mdi_plus.prediction_sims.tcga_brca_classification-"
+)
+
+keep_models <- c("RF", "RF-ridge", "RF-lasso", "RF-logistic")
+prediction_metrics <- c(
+ "r2", "explained_variance", "mean_squared_error", "mean_absolute_error",
+ "rocauc", "prauc", "accuracy", "f1", "recall", "precision", "avg_precision", "logloss"
+)
+
+results_ls <- list()
+for (fpath in fpaths) {
+ fname <- file.path(
+ results_dir,
+ fpath,
+ sprintf("seed%s", seed),
+ "results.csv"
+ )
+ results <- data.table::fread(fname) %>%
+ reformat_results(prediction = TRUE)
+
+ if (!("y_task" %in% colnames(results))) {
+ results <- results %>%
+ dplyr::mutate(y_task = "Results")
+ }
+
+ plt_df <- results %>%
+ tidyr::pivot_longer(
+ cols = tidyselect::any_of(prediction_metrics),
+ names_to = "metric", values_to = "value"
+ ) %>%
+ dplyr::group_by(y_task, model, metric) %>%
+ dplyr::summarise(
+ mean = mean(value),
+ sd = sd(value),
+ n = dplyr::n()
+ ) %>%
+ dplyr::mutate(
+ se = sd / sqrt(n)
+ )
+
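+  # Note: every entry of fpaths above ends in "-", so this comparison (which
+  # omits the trailing "-") never matches and `tab` is not generated here.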
+ if (fpath == "mdi_plus.prediction_sims.ccle_rnaseq_regression") {
+ tab <- plt_df %>%
+ dplyr::mutate(
+ value = sprintf("%.3f (%.3f)", mean, se)
+ ) %>%
+ tidyr::pivot_wider(id_cols = model, names_from = "metric", values_from = "value") %>%
+ dplyr::arrange(dplyr::desc(accuracy)) %>%
+ dplyr::select(model, accuracy, f1, rocauc, prauc, precision, recall) %>%
+ vthemes::pretty_kable(format = "latex")
+ }
+
+ if (fpath != "mdi_plus.prediction_sims.ccle_rnaseq_regression-") {
+ results_ls[[fpath]] <- plt_df %>%
+ dplyr::filter(model %in% c("RF", "RF-logistic")) %>%
+ dplyr::select(model, metric, mean) %>%
+ tidyr::pivot_wider(id_cols = metric, names_from = "model", values_from = "mean") %>%
+ dplyr::mutate(diff = `RF-logistic` - RF,
+ percent_diff = (`RF-logistic` - RF) / abs(RF) * 100)
+ } else {
+ results_ls[[fpath]] <- plt_df %>%
+ dplyr::filter(model %in% c("RF", "RF-ridge")) %>%
+ dplyr::select(y_task, model, metric, mean) %>%
+ tidyr::pivot_wider(id_cols = c(y_task, metric), names_from = "model", values_from = "mean") %>%
+ dplyr::mutate(diff = `RF-ridge` - RF,
+ percent_diff = (`RF-ridge` - RF) / abs(RF) * 100)
+ }
+}
+
+plt_reg <- results_ls$`mdi_plus.prediction_sims.ccle_rnaseq_regression-` %>%
+ dplyr::filter(metric == "r2") %>%
+ dplyr::filter(RF > 0.1) %>%
+ dplyr::mutate(
+ metric = ifelse(metric == "r2", "R-squared", metric)
+ ) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = RF, y = percent_diff, label = y_task) +
+ ggplot2::labs(x = "RF Test R-squared", y = "% Change Using RF+ (ridge)") +
+ ggplot2::geom_point(size = 3) +
+ ggplot2::geom_hline(yintercept = 0, linetype = "dashed") +
+ ggrepel::geom_label_repel(fill = "white") +
+ ggplot2::facet_wrap(~ metric, scales = "free") +
+ custom_theme
+
+plt_reg_all <- results_ls$`mdi_plus.prediction_sims.ccle_rnaseq_regression-` %>%
+ dplyr::filter(metric == "r2") %>%
+ dplyr::mutate(
+ metric = ifelse(metric == "r2", "R-squared", metric)
+ ) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = RF, y = percent_diff, label = y_task) +
+ ggplot2::labs(x = "RF Test R-squared", y = "% Change Using RF+ (ridge)") +
+ ggplot2::geom_point(size = 3) +
+ ggplot2::geom_hline(yintercept = 0, linetype = "dashed") +
+ ggrepel::geom_label_repel(fill = "white") +
+ ggplot2::facet_wrap(~ metric, scales = "free") +
+ custom_theme
+
+if (params$for_paper) {
+ ggplot2::ggsave(
+ plot = plt_reg_all,
+ file.path(figures_dir, "prediction_results_appendix.pdf"),
+ units = "in", width = 8, height = fig_height * 0.75
+ )
+}
+
+plt_ls <- list(reg = plt_reg)
+for (m in c("f1", "prauc")) {
+ plt_df <- dplyr::bind_rows(results_ls, .id = "dataset") %>%
+ dplyr::filter(metric == m) %>%
+ dplyr::mutate(
+ dataset = stringr::str_remove(dataset, "^mdi_plus\\.prediction_sims\\.") %>%
+ stringr::str_remove("_classification-$") %>%
+ stringr::str_to_title() %>%
+ stringr::str_replace("Tcga_brca", "TCGA BRCA"),
+ diff = ifelse(metric == "logloss", -diff, diff),
+ percent_diff = ifelse(metric == "logloss", -percent_diff, percent_diff),
+ metric = forcats::fct_recode(metric,
+ "Accuracy" = "accuracy",
+ "F1" = "f1",
+ "Negative Log-loss" = "logloss",
+ "AUPRC" = "prauc",
+ "AUROC" = "rocauc")
+ )
+ plt_ls[[m]] <- plt_df %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = RF, y = percent_diff, label = dataset) +
+ ggplot2::labs(x = sprintf("RF Test %s", plt_df$metric[1]),
+ y = "% Change Using RF+ (logistic)") +
+ ggplot2::geom_point(size = 3) +
+ ggplot2::geom_hline(yintercept = 0, linetype = "dashed") +
+ ggrepel::geom_label_repel(point.padding = 0.5, fill = "white") +
+ ggplot2::facet_wrap(~ metric, scales = "free", nrow = 1) +
+ custom_theme
+}
+
+plt_ls[[3]] <- plt_ls[[3]] +
+ ggplot2::theme(axis.title.y = ggplot2::element_blank())
+
+plt <- plt_ls[[1]] + patchwork::plot_spacer() + plt_ls[[2]] + plt_ls[[3]] +
+ patchwork::plot_layout(nrow = 1, widths = c(1, 0.3, 1, 1))
+
+# plt
+# grid::grid.draw(
+# grid::linesGrob(x = grid::unit(c(0.36, 0.36), "npc"),
+# y = grid::unit(c(0.06, 0.98), "npc"))
+# )
+# # ggplot2::ggsave(
+# # file.path(figures_dir, "prediction_results_main.pdf"),
+# # units = "in", width = 14, height = fig_height * 0.75
+# # )
+
+vthemes::subchunkify(plt, i = chunk_idx,
+ fig_height = fig_height, fig_width = fig_width * 1.3)
+chunk_idx <- chunk_idx + 1
+```
+
+
+# MDI+ Modeling Choices {.tabset .tabset-vmodern}
+
+```{r eval = TRUE, results = "asis"}
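+# MDI+ modeling-choice ablations: compare partial MDI+ variants (raw feature
+# only, LOO only, raw+LOO without ridge) against the full MDI+ (ridge+raw+loo),
+# MDI, and MDI-oob, repeating the sweep for min_samples_per_leaf of 5 and 1.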
+manual_color_palette <- c(manual_color_palette_choices[method_labels_choices == "MDI"],
+ manual_color_palette_choices[method_labels_choices == "MDI-oob"],
+ "#AF4D98", "#FFD92F", "#FC8D62", "black")
+show_methods <- c("MDI_RF", "MDI-oob_RF", "MDI+_raw_RF", "MDI+_loo_RF", "MDI+_ols_raw_loo_RF", "MDI+_ridge_raw_loo_RF")
+method_labels <- c("MDI", "MDI-oob", "MDI+ (raw only)", "MDI+ (loo only)", "MDI+ (raw+loo only)", "MDI+ (ridge+raw+loo)")
+alpha_values <- rep(1, length(method_labels))
+legend_position <- c(0.63, 0.4)
+
+vary_param_name <- "heritability_sample_row_n"
+y_models <- c("linear", "hier_poly_3m_2r")
+x_models <- c("enhancer", "ccle_rnaseq")#, "splicing")
+min_samples_per_leafs <- c(5, 1)
+
+remove_x_axis_models <- c("enhancer")
+keep_legend_x_models <- c("enhancer")
+keep_legend_y_models <- c("linear")
+
+for (y_model in y_models) {
+ cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
+ for (x_model in x_models) {
+ cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model))
+ plt_ls <- list()
+ for (min_samples_per_leaf in min_samples_per_leafs) {
+ cat(sprintf("\n\n#### min_samples_per_leaf = %s \n\n", min_samples_per_leaf))
+ sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
+ sim_title <- dplyr::case_when(
+ x_model == "ccle_rnaseq" ~ "CCLE",
+ x_model == "splicing" ~ "Splicing",
+ x_model == "enhancer" ~ "Enhancer",
+ x_model == "juvenile" ~ "Juvenile"
+ )
+ metric_name <- dplyr::case_when(
+ metric == "rocauc" ~ "AUROC",
+ metric == "prauc" ~ "PRAUC"
+ )
+ fname <- file.path(results_dir,
+ sprintf("mdi_plus.other_regression_sims.modeling_choices_min_samples%s.%s", min_samples_per_leaf, sim_name),
+ paste0("varying_", vary_param_name),
+ paste0("seed", seed), "results")
+ if (!file.exists(sprintf("%s.csv", fname))) {
+ next
+ }
+ if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
+ results <- readRDS(sprintf("%s.rds", fname))
+ if (length(setdiff(show_methods, unique(results$method))) > 0) {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ } else {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ if (params$for_paper) {
+ for (h in heritabilities) {
+ plt <- results %>%
+ dplyr::filter(heritability == !!h) %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_row_n",
+ facet_str = NULL,
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ legend_position = legend_position,
+ custom_theme = custom_theme,
+ inside_legend = TRUE
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method"
+ ) +
+ ggplot2::theme(
+ legend.text = ggplot2::element_text(size = 12)
+ )
+ if (h != heritabilities[1]) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ if (x_model %in% remove_x_axis_models) {
+ height <- 2.72
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ } else {
+ height <- 3
+ }
+ if (!((h == heritabilities[length(heritabilities)]) &
+ (x_model %in% keep_legend_x_models) &
+ (y_model %in% keep_legend_y_models))) {
+ plt <- plt + ggplot2::guides(color = "none", alpha = "none")
+ }
+ plt_ls[[as.character(h)]] <- plt
+ }
+ plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
+ ggplot2::ggsave(
+ file.path(figures_dir, "modeling_choices",
+ sprintf("regression_sims_choices_min_samples%s_%s_%s_%s_errbars.pdf",
+ min_samples_per_leaf, y_model, x_model, metric)),
+ plot = plt, units = "in", width = 14, height = height
+ )
+ } else {
+ plt <- results %>%
+ plot_metrics(
+ metric = metric,
+ x_str = "sample_row_n",
+ facet_str = "heritability_name",
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ custom_theme = custom_theme
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method",
+ title = sprintf("%s", sim_title)
+ )
+ vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
+ fig_height = fig_height, fig_width = fig_width * 1.3)
+ chunk_idx <- chunk_idx + 1
+ }
+ }
+ }
+}
+
+```
+
+
+# MDI+ GLM/Metric Choices {.tabset .tabset-vmodern}
+
+## Held-out Test Prediction Scores {.tabset .tabset-pills .tabset-square}
+
+```{r eval = TRUE, results = "asis"}
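+# MDI+ GLM/metric choices, prediction accuracy: compare held-out R-squared of
+# plain RF against RF+ with ridge and with lasso across sample sizes and
+# heritability levels.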
+manual_color_palette <- c("#A5AE9E", "black", "brown")
+show_methods <- c("RF", "RF-ridge", "RF-lasso")
+method_labels <- c("RF", "RF+ridge", "RF+lasso")
+alpha_values <- rep(1, length(method_labels))
+legend_position <- c(0.63, 0.4)
+
+vary_param_name <- "heritability_sample_row_n"
+y_models <- c("linear", "hier_poly_3m_2r")
+x_models <- c("enhancer", "ccle_rnaseq")#, "splicing")
+prediction_metrics <- c(
+ "r2" #, "mean_absolute_error"
+ # "rocauc", "prauc", "accuracy", "f1", "recall", "precision", "avg_precision", "logloss"
+)
+
+remove_x_axis_models <- c("enhancer")
+keep_legend_x_models <- c("enhancer")
+keep_legend_y_models <- c("linear")
+
+for (y_model in y_models) {
+ cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
+ for (x_model in x_models) {
+ cat(sprintf("\n\n#### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model))
+ plt_ls <- list()
+ sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
+ sim_title <- dplyr::case_when(
+ x_model == "ccle_rnaseq" ~ "CCLE",
+ x_model == "splicing" ~ "Splicing",
+ x_model == "enhancer" ~ "Enhancer",
+ x_model == "juvenile" ~ "Juvenile"
+ )
+ fname <- file.path(results_dir,
+ sprintf("mdi_plus.glm_metric_choices_sims.regression_prediction_sims.%s", sim_name),
+ paste0("varying_", vary_param_name),
+ paste0("seed", seed), "results")
+ if (!file.exists(sprintf("%s.csv", fname))) {
+ next
+ }
+ if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
+ results <- readRDS(sprintf("%s.rds", fname))
+ if (length(setdiff(show_methods, unique(results$method))) > 0) {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results(prediction = TRUE)
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ } else {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results(prediction = TRUE)
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ for (m in prediction_metrics) {
+ metric_name <- dplyr::case_when(
+ m == "r2" ~ "R-squared",
+ m == "mae" ~ "Mean Absolute Error"
+ )
+ if (params$for_paper) {
+ for (h in heritabilities) {
+ plt <- results %>%
+ dplyr::filter(heritability == !!h) %>%
+ plot_metrics(
+ metric = m,
+ x_str = "sample_row_n",
+ facet_str = NULL,
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ legend_position = legend_position,
+ custom_theme = custom_theme,
+ inside_legend = TRUE
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method"
+ ) +
+ ggplot2::theme(
+ legend.text = ggplot2::element_text(size = 12)
+ )
+ if (h != heritabilities[1]) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ if (x_model %in% remove_x_axis_models) {
+ height <- 2.72
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ } else {
+ height <- 3
+ }
+ if (!((h == heritabilities[length(heritabilities)]) &
+ (x_model %in% keep_legend_x_models) &
+ (y_model %in% keep_legend_y_models))) {
+ plt <- plt + ggplot2::guides(color = "none", alpha = "none")
+ }
+ plt_ls[[as.character(h)]] <- plt
+ }
+ plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
+ ggplot2::ggsave(
+ file.path(figures_dir, "glm_metric_choices",
+ sprintf("regression_sims_glm_metric_choices_prediction_%s_%s_%s_errbars.pdf",
+ y_model, x_model, m)),
+ plot = plt, units = "in", width = 14, height = height
+ )
+ } else {
+ plt <- results %>%
+ plot_metrics(
+ metric = m,
+ x_str = "sample_row_n",
+ facet_str = "heritability_name",
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ custom_theme = custom_theme
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method",
+ title = sprintf("%s", sim_title)
+ )
+ vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
+ fig_height = fig_height, fig_width = fig_width * 1.3)
+ chunk_idx <- chunk_idx + 1
+ }
+ }
+ }
+}
+
+```
+
+## Stability Scores - Regression {.tabset .tabset-pills .tabset-square}
+
+```{r eval = TRUE, results = "asis"}
+manual_color_palette <- c("black", "black", "brown", "brown", "gray",
+ manual_color_palette_choices[method_labels_choices == "MDI"],
+ manual_color_palette_choices[method_labels_choices == "MDI-oob"])
+show_methods <- c("MDI+_ridge_r2_RF", "MDI+_ridge_neg_mae_RF", "MDI+_lasso_r2_RF", "MDI+_lasso_neg_mae_RF", "MDI+_ensemble_RF", "MDI_RF", "MDI-oob_RF")
+method_labels <- c("MDI+ (ridge, r-squared)", "MDI+ (ridge, MAE)", "MDI+ (lasso, r-squared)", "MDI+ (lasso, MAE)", "MDI+ (ensemble)", "MDI", "MDI-oob")
+manual_linetype_palette <- c(1, 2, 1, 2, 1, 1, 1)
+alpha_values <- c(rep(1, length(method_labels) - 2), rep(0.4, 2))
+legend_position <- c(0.63, 0.4)
+
+vary_param_name <- "heritability_sample_row_n"
+# y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
+# x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")
+y_models <- c("linear", "hier_poly_3m_2r")
+x_models <- c("enhancer", "ccle_rnaseq")
+# stability_metrics <- c("tauAP", "RBO")
+stability_metrics <- c("RBO")
+
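+ # Panel-layout controls: remove_x_axis_models holds the metrics whose x-axis title is dropped in the stacked figure; legend guides are kept only on the last heritability panel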
+remove_x_axis_models <- metric
+keep_legend_x_models <- x_models
+keep_legend_y_models <- y_models
+keep_legend_metrics <- metric
+
+for (y_model in y_models) {
+ cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
+ for (x_model in x_models) {
+ cat(sprintf("\n\n#### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model))
+ sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
+ sim_title <- dplyr::case_when(
+ x_model == "ccle_rnaseq" ~ "CCLE",
+ x_model == "splicing" ~ "Splicing",
+ x_model == "enhancer" ~ "Enhancer",
+ x_model == "juvenile" ~ "Juvenile"
+ )
+ fname <- file.path(results_dir,
+ sprintf("mdi_plus.glm_metric_choices_sims.regression_sims.%s", sim_name),
+ paste0("varying_", vary_param_name),
+ paste0("seed", seed), "results")
+ if (!file.exists(sprintf("%s.csv", fname))) {
+ next
+ }
+ if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
+ results <- readRDS(sprintf("%s.rds", fname))
+ if (length(setdiff(show_methods, unique(results$method))) > 0) {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ } else {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ plt_ls <- list()
+ for (m in c(metric, stability_metrics)) {
+ metric_name <- dplyr::case_when(
+ m == "rocauc" ~ "AUROC",
+ m == "prauc" ~ "PRAUC",
+ TRUE ~ m
+ )
+ if (params$for_paper) {
+ for (h in heritabilities) {
+ plt <- results %>%
+ dplyr::filter(heritability == !!h) %>%
+ plot_metrics(
+ metric = m,
+ x_str = "sample_row_n",
+ facet_str = NULL,
+ linetype_str = "method",
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ legend_position = legend_position,
+ custom_theme = custom_theme,
+ inside_legend = FALSE
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method", linetype = "Method"
+ ) +
+ ggplot2::scale_linetype_manual(
+ values = manual_linetype_palette, labels = method_labels
+ ) +
+ ggplot2::theme(
+ legend.text = ggplot2::element_text(size = 12),
+ legend.key.width = ggplot2::unit(1.9, "cm")
+ )
+ if (h != heritabilities[1]) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ if (m %in% remove_x_axis_models) {
+ height <- 2.72
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ } else {
+ height <- 3
+ }
+ if (!((h == heritabilities[length(heritabilities)]) &
+ (x_model %in% keep_legend_x_models) &
+ (y_model %in% keep_legend_y_models) &
+ (metric %in% keep_legend_metrics))) {
+ plt <- plt + ggplot2::guides(color = "none", alpha = "none", linetype = "none")
+ }
+ plt_ls[[sprintf("%s_%s", as.character(h), m)]] <- plt
+ }
+ } else {
+ plt <- results %>%
+ plot_metrics(
+ metric = m,
+ x_str = "sample_row_n",
+ facet_str = "heritability_name",
+ linetype_str = "method",
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ custom_theme = custom_theme
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method", linetype = "Method",
+ title = sprintf("%s", sim_title)
+ ) +
+ ggplot2::scale_linetype_manual(
+ values = manual_linetype_palette, labels = method_labels
+ ) +
+ ggplot2::theme(
+ legend.key.width = ggplot2::unit(1.9, "cm")
+ )
+ vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
+ fig_height = fig_height, fig_width = fig_width * 1.5)
+ chunk_idx <- chunk_idx + 1
+ }
+ }
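+ # Stack one row of panels per metric (AUROC on top, RBO below) and collect a single shared legend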
+ if (params$for_paper) {
+ nrows <- length(c(metric, stability_metrics))
+ plt <- patchwork::wrap_plots(plt_ls, nrow = nrows, guides = "collect")
+ ggplot2::ggsave(
+ file.path(figures_dir, "glm_metric_choices",
+ sprintf("regression_sims_glm_metric_choices_stability_%s_%s_errbars.pdf",
+ y_model, x_model)),
+ plot = plt, units = "in", width = 14 * 1.2, height = height * nrows
+ )
+ }
+ }
+}
+
+```
+
+## Stability Scores - Classification {.tabset .tabset-pills .tabset-square}
+
+```{r eval = TRUE, results = "asis"}
+manual_color_palette <- c("#9B5DFF", "#9B5DFF", "black",
+ manual_color_palette_choices[method_labels_choices == "MDI"],
+ manual_color_palette_choices[method_labels_choices == "MDI-oob"])
+show_methods <- c("MDI+_logistic_ridge_logloss_RF", "MDI+_logistic_ridge_auroc_RF", "MDI+_ridge_r2_RF", "MDI_RF", "MDI-oob_RF")
+method_labels <- c("MDI+ (logistic, log-loss)", "MDI+ (logistic, AUROC)", "MDI+ (ridge, r-squared)", "MDI", "MDI-oob")
+manual_linetype_palette <- c(1, 2, 1, 1, 1)
+alpha_values <- c(rep(1, length(method_labels) - 2), rep(0.4, 2))
+legend_position <- c(0.63, 0.4)
+
+vary_param_name <- "frac_label_corruption_sample_row_n"
+# y_models <- c("logistic", "lss_3m_2r_logistic", "hier_poly_3m_2r_logistic", "linear_lss_3m_2r_logistic")
+# x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")
+y_models <- c("logistic", "hier_poly_3m_2r_logistic")
+x_models <- c("juvenile", "splicing")
+# stability_metrics <- c("tauAP", "RBO")
+stability_metrics <- "RBO"
+
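+ # Same panel-layout controls as in the regression stability chunk above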
+remove_x_axis_models <- metric
+keep_legend_x_models <- x_models
+keep_legend_y_models <- y_models
+keep_legend_metrics <- metric
+
+for (y_model in y_models) {
+ cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
+ for (x_model in x_models) {
+ cat(sprintf("\n\n#### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model))
+ sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
+ sim_title <- dplyr::case_when(
+ x_model == "ccle_rnaseq" ~ "CCLE",
+ x_model == "splicing" ~ "Splicing",
+ x_model == "enhancer" ~ "Enhancer",
+ x_model == "juvenile" ~ "Juvenile"
+ )
+ fname <- file.path(results_dir,
+ sprintf("mdi_plus.glm_metric_choices_sims.classification_sims.%s", sim_name),
+ paste0("varying_", vary_param_name),
+ paste0("seed", seed), "results")
+ if (!file.exists(sprintf("%s.csv", fname))) {
+ next
+ }
+ if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
+ results <- readRDS(sprintf("%s.rds", fname))
+ if (length(setdiff(show_methods, unique(results$method))) > 0) {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ } else {
+ results <- data.table::fread(sprintf("%s.csv", fname)) %>%
+ reformat_results()
+ saveRDS(results, sprintf("%s.rds", fname))
+ }
+ plt_ls <- list()
+ for (m in c(metric, stability_metrics)) {
+ metric_name <- dplyr::case_when(
+ m == "rocauc" ~ "AUROC",
+ m == "prauc" ~ "PRAUC",
+ TRUE ~ m
+ )
+ if (params$for_paper) {
+ for (h in frac_label_corruptions) {
+ plt <- results %>%
+ dplyr::filter(frac_label_corruption_name == !!h) %>%
+ plot_metrics(
+ metric = m,
+ x_str = "sample_row_n",
+ facet_str = NULL,
+ linetype_str = "method",
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ legend_position = legend_position,
+ custom_theme = custom_theme,
+ inside_legend = FALSE
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method", linetype = "Method"
+ ) +
+ ggplot2::scale_linetype_manual(
+ values = manual_linetype_palette, labels = method_labels
+ ) +
+ ggplot2::theme(
+ legend.text = ggplot2::element_text(size = 12),
+ legend.key.width = ggplot2::unit(1.9, "cm")
+ )
+ if (h != frac_label_corruptions[1]) {
+ plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
+ }
+ if (m %in% remove_x_axis_models) {
+ height <- 2.72
+ plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
+ } else {
+ height <- 3
+ }
+ if (!((h == frac_label_corruptions[length(frac_label_corruptions)]) &
+ (x_model %in% keep_legend_x_models) &
+ (y_model %in% keep_legend_y_models) &
+ (metric %in% keep_legend_metrics))) {
+ plt <- plt + ggplot2::guides(color = "none", alpha = "none", linetype = "none")
+ }
+ plt_ls[[sprintf("%s_%s", as.character(h), m)]] <- plt
+ }
+ } else {
+ plt <- results %>%
+ plot_metrics(
+ metric = m,
+ x_str = "sample_row_n",
+ facet_str = "frac_label_corruption_name",
+ linetype_str = "method",
+ point_size = point_size,
+ line_size = line_size,
+ errbar_width = errbar_width,
+ manual_color_palette = manual_color_palette,
+ show_methods = show_methods,
+ method_labels = method_labels,
+ alpha_values = alpha_values,
+ custom_theme = custom_theme
+ ) +
+ ggplot2::labs(
+ x = "Sample Size", y = metric_name,
+ color = "Method", alpha = "Method", linetype = "Method",
+ title = sprintf("%s", sim_title)
+ ) +
+ ggplot2::scale_linetype_manual(
+ values = manual_linetype_palette, labels = method_labels
+ ) +
+ ggplot2::theme(
+ legend.key.width = ggplot2::unit(1.9, "cm")
+ )
+ vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
+ fig_height = fig_height, fig_width = fig_width * 1.5)
+ chunk_idx <- chunk_idx + 1
+ }
+ }
+ if (params$for_paper) {
+ nrows <- length(c(metric, stability_metrics))
+ plt <- patchwork::wrap_plots(plt_ls, nrow = nrows, guides = "collect")
+ ggplot2::ggsave(
+ file.path(figures_dir, "glm_metric_choices",
+ sprintf("classification_sims_glm_metric_choices_stability_%s_%s_errbars.pdf",
+ y_model, x_model)),
+ plot = plt, units = "in", width = 14 * 1.2, height = height * nrows
+ )
+ }
+ }
+}
+
+```
diff --git a/feature_importance/notebooks/mdi_plus/01_paper_figures.html b/feature_importance/notebooks/mdi_plus/01_paper_figures.html
new file mode 100644
index 0000000..8561f39
--- /dev/null
+++ b/feature_importance/notebooks/mdi_plus/01_paper_figures.html
@@ -0,0 +1,5294 @@
+<!--
+Rendered report: "MDI+ Simulation Results Summary" (source R Markdown appended below).
+Top-level sections: Regression Simulations; Classification Simulations; Robust Simulations;
+Correlation Bias Simulations; Entropy Bias Simulations; CCLE Case Study (Summary; Table of
+Top Genes; per-drug results for 24 drugs); TCGA BRCA Case Study (Non-zero Stability Scores
+Per Method; Stability of Top Features; Table of Top Genes; Prediction using Top Features);
+Misspecified Models; Varying Sparsity; Varying # Features; Prediction Results; MDI+ Modeling
+Choices; MDI+ GLM/Metric Choices (Held-out Test Prediction Scores, Stability Scores -
+Regression, Stability Scores - Classification).
+The CCLE "Table of Top Genes" lists, for each of the 24 drugs, the top 5 genes with their
+average ranks under MDI+ (ridge), TreeSHAP, MDI, MDA, and MDI-oob; the TCGA BRCA "Table of
+Top Genes" lists the top 25 genes with average ranks under MDI+ (ridge), MDI+ (logistic),
+MDA, TreeSHAP, and MDI.
+-->
---
title: "MDI+ Simulation Results Summary"
author: ""
date: "`r format(Sys.time(), '%B %d, %Y')`"
output: vthemes::vmodern
params:
  results_dir: 
    label: "Results directory"
    value: "/global/scratch/users/tiffanytang/feature_importance/results_final/"
  seed:
    label: "Seed"
    value: 12345
  for_paper:
    label: "Export plots for paper"
    value: FALSE
  use_cached:
    label: "Use cached .rds files"
    value: TRUE
  interactive:
    label: "Interactive plots"
    value: FALSE
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = FALSE, warning = FALSE, message = FALSE)

library(magrittr)
library(patchwork)
chunk_idx <- 1

# set parameters
results_dir <- params$results_dir
seed <- params$seed
tables_dir <- file.path("tables")
figures_dir <- file.path("figures")
figures_subdirs <- c("regression_sims", "classification_sims", "robust_sims",
                     "misspecified_regression_sims", "varying_sparsity",
                     "varying_p", "modeling_choices", "glm_metric_choices")
for (figures_subdir in figures_subdirs) {
  if (!dir.exists(file.path(figures_dir, figures_subdir))) {
    dir.create(file.path(figures_dir, figures_subdir), recursive = TRUE)
  }
}

# miscellaneous helper variables
heritabilities <- c(0.1, 0.2, 0.4, 0.8)
frac_label_corruptions <- c("0.25", "0.15", "0.05", "0")
corrupt_sizes <- c(0.05, 0.025, 0.01, 0)
mean_shifts <- c(10, 25)
metric <- "rocauc"

# plot options
point_size <- 2
line_size <- 1
errbar_width <- 0
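# Optionally render plots interactively with plotly; otherwise return the ggplot object unchanged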
if (params$interactive) {
  plot_fun <- plotly::ggplotly
} else {
  plot_fun <- function(x) x
}

manual_color_palette_choices <- c(
  "black", "black", "#9B5DFF", "blue",
  "orange", "#71beb7", "#218a1e", "#cc3399"
)
show_methods_choices <- c(
  "GMDI_ridge_RF", "GMDI_ridge_loo_r2_RF", "GMDI_logistic_logloss_RF", "GMDI_Huber_loo_huber_loss_RF",
  "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF"
)
method_labels_choices <- c(
  "MDI+ (ridge)", "MDI+ (ridge)", "MDI+ (logistic)", "MDI+ (Huber)",
  "MDA", "MDI", "MDI-oob", "TreeSHAP"
)
color_df <- tibble::tibble(
  color = manual_color_palette_choices,
  name = show_methods_choices,
  label = method_labels_choices
)

manual_color_palette_all <- NULL
show_methods_all <- NULL
method_labels_all <- ggplot2::waiver()

custom_theme <- vthemes::theme_vmodern(
  size_preset = "medium", bg_color = "white", grid_color = "white",
  axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
  axis.text = ggplot2::element_text(size = 14),
  axis.title = ggplot2::element_text(size = 20, face = "plain"),
  legend.text = ggplot2::element_text(size = 14),
  plot.title = ggplot2::element_blank()
  # plot.title = ggplot2::element_text(size = 12, face = "plain", hjust = 0.5)
)
custom_theme_with_legend <- vthemes::theme_vmodern(
  size_preset = "medium", bg_color = "white", grid_color = "white",
  axis.title = ggplot2::element_text(size = 12, face = "plain"),
  legend.text = ggplot2::element_text(size = 9),
  legend.text.align = 0,
  plot.title = ggplot2::element_blank()
)
custom_theme_without_legend <- vthemes::theme_vmodern(
  size_preset = "medium", bg_color = "white", grid_color = "white",
  axis.title = ggplot2::element_text(size = 12, face = "plain"),
  legend.title = ggplot2::element_blank(),
  legend.text = ggplot2::element_text(size = 9),
  legend.text.align = 0,
  plot.title = ggplot2::element_blank()
)

fig_height <- 6
fig_width <- 10

source("../../scripts/viz.R", chdir = TRUE)
```
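
To regenerate this summary or export the paper figures, the report can be rendered with custom parameters. A minimal sketch is below; the file name and results path are placeholders.

```r
# Render the parameterized report; for_paper = TRUE writes the PDF figures under figures/
# (the .Rmd file name and results_dir below are illustrative placeholders)
rmarkdown::render(
  "01_paper_figures.Rmd",
  params = list(
    results_dir = "/path/to/feature_importance/results",
    seed = 12345,
    for_paper = TRUE,
    use_cached = TRUE
  )
)
```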


# Regression Simulations {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
keep_methods <- color_df$name %in% c("GMDI_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
manual_color_palette <- color_df$color[keep_methods]
show_methods <- color_df$name[keep_methods]
method_labels <- color_df$label[keep_methods]
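# Emphasize MDI+ (ridge) at full opacity and fade the competing methods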
alpha_values <- c(1, rep(0.4, length(method_labels) - 1))
legend_position <- c(0.73, 0.35)

vary_param_name <- "heritability_sample_row_n"
y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")

remove_x_axis_models <- c("enhancer", "ccle_rnaseq", "juvenile")
keep_legend_x_models <- c("enhancer")
keep_legend_y_models <- y_models

for (y_model in y_models) {
  cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n### %s \n\n", x_model))
    plt_ls <- list()
    sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    metric_name <- dplyr::case_when(
      metric == "rocauc" ~ "AUROC",
      metric == "prauc" ~ "PRAUC"
    )
    fname <- file.path(results_dir, paste0("mdi_plus.regression_sims.", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results()
      saveRDS(results, sprintf("%s.rds", fname))
    }
    if (params$for_paper) {
      for (h in heritabilities) {
        plt <- results %>%
          dplyr::filter(heritability == !!h) %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_row_n",
            facet_str = NULL,
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            legend_position = legend_position,
            custom_theme = custom_theme,
            inside_legend = TRUE
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name, 
            color = "Method", alpha = "Method"
          )
        if (h != heritabilities[1]) {
          plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
        }
        if (x_model %in% remove_x_axis_models) {
          height <- 2.72
          plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
        } else {
          height <- 3
        }
        if (!((h == heritabilities[length(heritabilities)]) & 
              (x_model %in% keep_legend_x_models) & 
              (y_model %in% keep_legend_y_models))) {
          plt <- plt + ggplot2::guides(color = "none", alpha = "none")
        }
        plt_ls[[as.character(h)]] <- plt
      }
      plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
      ggplot2::ggsave(
        file.path(figures_dir, "regression_sims",
                  sprintf("regression_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)),
        plot = plt, units = "in", width = 14, height = height
      )
    } else {
      plt <- results %>%
        plot_metrics(
          metric = metric,
          x_str = "sample_row_n",
          facet_str = "heritability_name",
          point_size = point_size,
          line_size = line_size,
          errbar_width = errbar_width,
          manual_color_palette = manual_color_palette,
          show_methods = show_methods,
          method_labels = method_labels,
          alpha_values = alpha_values,
          custom_theme = custom_theme
        ) +
        ggplot2::labs(
          x = "Sample Size", y = metric_name,
          color = "Method", alpha = "Method",
          title = sprintf("%s", sim_title)
        )
      vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                           fig_height = fig_height, fig_width = fig_width * 1.3)
      chunk_idx <- chunk_idx + 1
    }
  }
}

```

```{r eval = TRUE, results = "asis"}
y_models <- c("lss_3m_2r", "hier_poly_3m_2r")
x_models <- c("splicing")

remove_x_axis_models <- c("lss_3m_2r")
keep_legend_x_models <- x_models
keep_legend_y_models <- c("lss_3m_2r")

if (params$for_paper) {
  for (y_model in y_models) {
    for (x_model in x_models) {
      plt_ls <- list()
      sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
      sim_title <- dplyr::case_when(
        x_model == "ccle_rnaseq" ~ "CCLE",
        x_model == "splicing" ~ "Splicing",
        x_model == "enhancer" ~ "Enhancer",
        x_model == "juvenile" ~ "Juvenile"
      )
      metric_name <- dplyr::case_when(
        metric == "rocauc" ~ "AUROC",
        metric == "prauc" ~ "PRAUC"
      )
      fname <- file.path(results_dir, paste0("mdi_plus.regression_sims.", sim_name),
                         paste0("varying_", vary_param_name), 
                         paste0("seed", seed), "results")
      if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
        results <- readRDS(sprintf("%s.rds", fname))
        if (length(setdiff(show_methods, unique(results$method))) > 0) {
          results <- data.table::fread(sprintf("%s.csv", fname)) %>%
            reformat_results()
          saveRDS(results, sprintf("%s.rds", fname))
        }
      } else {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
      for (h in heritabilities) {
        plt <- results %>%
          dplyr::filter(heritability == !!h) %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_row_n",
            facet_str = NULL,
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            legend_position = legend_position,
            custom_theme = custom_theme,
            inside_legend = TRUE
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name, 
            color = "Method", alpha = "Method"
          )
        if (h != heritabilities[1]) {
          plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
        }
        if (y_model %in% remove_x_axis_models) {
          height <- 2.72
          plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
        } else {
          height <- 3
        }
        if (!((h == heritabilities[length(heritabilities)]) & 
              (x_model %in% keep_legend_x_models) & 
              (y_model %in% keep_legend_y_models))) {
          plt <- plt + ggplot2::guides(color = "none", alpha = "none")
        }
        plt_ls[[as.character(h)]] <- plt
      }
      plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
      ggplot2::ggsave(
        file.path(figures_dir, "regression_sims",
                  sprintf("main_regression_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)),
        plot = plt, units = "in", width = 14, height = height
      )
    }
  }
}

```


# Classification Simulations {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
keep_methods <- color_df$name %in% c("GMDI_ridge_RF", "GMDI_logistic_logloss_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
manual_color_palette <- color_df$color[keep_methods]
show_methods <- color_df$name[keep_methods]
method_labels <- color_df$label[keep_methods]
alpha_values <- c(1, 1, rep(0.4, length(method_labels) - 2))
legend_position <- c(0.73, 0.4)

vary_param_name <- "frac_label_corruption_sample_row_n"
y_models <- c("logistic", "lss_3m_2r_logistic", "hier_poly_3m_2r_logistic", "linear_lss_3m_2r_logistic")
x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")

remove_x_axis_models <- c("enhancer", "ccle_rnaseq", "juvenile")
keep_legend_x_models <- c("enhancer")
keep_legend_y_models <- y_models

for (y_model in y_models) {
  cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n### %s \n\n", x_model))
    plt_ls <- list()
    sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    metric_name <- dplyr::case_when(
      metric == "rocauc" ~ "AUROC",
      metric == "prauc" ~ "PRAUC"
    )
    fname <- file.path(results_dir, paste0("mdi_plus.classification_sims.", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results()
      saveRDS(results, sprintf("%s.rds", fname))
    }
    if (params$for_paper) {
      for (h in frac_label_corruptions) {
        plt <- results %>%
          dplyr::filter(frac_label_corruption_name == !!h) %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_row_n",
            facet_str = NULL,
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            legend_position = legend_position,
            custom_theme = custom_theme,
            inside_legend = TRUE
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name, 
            color = "Method", alpha = "Method"
          )
        if (h != frac_label_corruptions[1]) {
          plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
        }
        if (x_model %in% remove_x_axis_models) {
          plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          height <- 2.72
        } else {
          height <- 3
        }
        if (!((h == frac_label_corruptions[length(frac_label_corruptions)]) & 
              (x_model %in% keep_legend_x_models) & 
              (y_model %in% keep_legend_y_models))) {
          plt <- plt + ggplot2::guides(color = "none", alpha = "none")
        }
        plt_ls[[as.character(h)]] <- plt
      }
      plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
      ggplot2::ggsave(
        file.path(figures_dir, "classification_sims",
                  sprintf("classification_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)),
        plot = plt, units = "in", width = 14, height = height
      )
    } else {
      plt <- results %>%
        plot_metrics(
          metric = metric,
          x_str = "sample_row_n",
          facet_str = "frac_label_corruption_name",
          point_size = point_size,
          line_size = line_size,
          errbar_width = errbar_width,
          manual_color_palette = manual_color_palette,
          show_methods = show_methods,
          method_labels = method_labels,
          alpha_values = alpha_values,
          custom_theme = custom_theme
        ) +
        ggplot2::labs(
          x = "Sample Size", y = metric_name,
          color = "Method", alpha = "Method",
          title = sprintf("%s", sim_title)
        )
      vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                           fig_height = fig_height, fig_width = fig_width * 1.3)
      chunk_idx <- chunk_idx + 1
    }
  }
}
```

```{r eval = TRUE, results = "asis"}
y_models <- c("logistic", "linear_lss_3m_2r_logistic")
x_models <- c("ccle_rnaseq")

remove_x_axis_models <- c("logistic")
keep_legend_x_models <- x_models
keep_legend_y_models <- c("logistic")

if (params$for_paper) {
  for (y_model in y_models) {
    for (x_model in x_models) {
      plt_ls <- list()
      sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
      sim_title <- dplyr::case_when(
        x_model == "ccle_rnaseq" ~ "CCLE",
        x_model == "splicing" ~ "Splicing",
        x_model == "enhancer" ~ "Enhancer",
        x_model == "juvenile" ~ "Juvenile"
      )
      metric_name <- dplyr::case_when(
        metric == "rocauc" ~ "AUROC",
        metric == "prauc" ~ "PRAUC"
      )
      fname <- file.path(results_dir, paste0("mdi_plus.classification_sims.", sim_name),
                         paste0("varying_", vary_param_name), 
                         paste0("seed", seed), "results")
      if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
        results <- readRDS(sprintf("%s.rds", fname))
        if (length(setdiff(show_methods, unique(results$method))) > 0) {
          results <- data.table::fread(sprintf("%s.csv", fname)) %>%
            reformat_results()
          saveRDS(results, sprintf("%s.rds", fname))
        }
      } else {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
      for (h in frac_label_corruptions) {
        plt <- results %>%
          dplyr::filter(frac_label_corruption_name == !!h) %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_row_n",
            facet_str = NULL,
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            legend_position = legend_position,
            custom_theme = custom_theme,
            inside_legend = TRUE
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name, 
            color = "Method", alpha = "Method"
          )
        if (h != frac_label_corruptions[1]) {
          plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
        }
        if (y_model %in% remove_x_axis_models) {
          plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          height <- 2.72
        } else {
          height <- 3
        }
        if (!((h == frac_label_corruptions[length(frac_label_corruptions)]) & 
              (x_model %in% keep_legend_x_models) & 
              (y_model %in% keep_legend_y_models))) {
          plt <- plt + ggplot2::guides(color = "none", alpha = "none")
        }
        plt_ls[[as.character(h)]] <- plt
      }
      plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
      ggplot2::ggsave(
        file.path(figures_dir, "classification_sims",
                  sprintf("main_classification_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)),
        plot = plt, units = "in", width = 14, height = height
      )
    }
  }
}
```


# Robust Simulations {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
keep_methods <- color_df$name %in% c("GMDI_ridge_loo_r2_RF", "GMDI_Huber_loo_huber_loss_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
manual_color_palette <- color_df$color[keep_methods]
show_methods <- color_df$name[keep_methods]
method_labels <- color_df$label[keep_methods]
alpha_values <- c(1, 1, rep(0.4, length(method_labels) - 2))
legend_position <- c(0.73, 0.4)

vary_param_name <- "corrupt_size_sample_row_n"
y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
x_models <- c("enhancer", "ccle_rnaseq")

remove_x_axis_models <- c(10)
keep_legend_x_models <- c("enhancer", "ccle_rnaseq")
keep_legend_y_models <- c("linear", "linear_lss_3m_2r")
keep_legend_mean_shifts <- c(10)

for (y_model in y_models) {
  cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-circle} \n\n", x_model))
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
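    # One tab (and one saved figure) per mean-shift magnitude (10 and 25)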
    for (mean_shift in mean_shifts) {
      cat(sprintf("\n\n#### Mean Shift = %s \n\n", mean_shift))
      plt_ls <- list()
      sim_name <- sprintf("%s_%s_%sMS_robust_dgp", x_model, y_model, mean_shift)
      fname <- file.path(results_dir, paste0("mdi_plus.robust_sims.", sim_name),
                         paste0("varying_", vary_param_name), 
                         paste0("seed", seed), "results")
      if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
        results <- readRDS(sprintf("%s.rds", fname))
        if (length(setdiff(show_methods, unique(results$method))) > 0) {
          results <- data.table::fread(sprintf("%s.csv", fname)) %>%
            reformat_results()
          saveRDS(results, sprintf("%s.rds", fname))
        }
      } else {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
      metric_name <- dplyr::case_when(
        metric == "rocauc" ~ "AUROC",
        metric == "prauc" ~ "PRAUC",
        TRUE ~ metric
      )
      if (params$for_paper) {
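        # Range of mean scores across corruption levels and methods (min_y, max_y)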
        tmp <- results %>%
          dplyr::filter(corrupt_size_name %in% corrupt_sizes,
                        method %in% show_methods) %>%
          dplyr::group_by(sample_row_n, corrupt_size_name, method) %>%
          dplyr::summarise(mean = mean(.data[[metric]]))
        min_y <- min(tmp$mean)
        max_y <- max(tmp$mean)
        for (h in corrupt_sizes) {
          plt <- results %>%
            dplyr::filter(corrupt_size_name == !!h) %>%
            plot_metrics(
              metric = metric,
              x_str = "sample_row_n",
              facet_str = NULL,
              point_size = point_size,
              line_size = line_size,
              errbar_width = errbar_width,
              manual_color_palette = manual_color_palette,
              show_methods = show_methods,
              method_labels = method_labels,
              alpha_values = alpha_values,
              legend_position = legend_position,
              custom_theme = custom_theme,
              inside_legend = TRUE
            ) +
            ggplot2::labs(
              x = "Sample Size", y = metric_name, 
              color = "Method", alpha = "Method"
            )
          if (h != corrupt_sizes[1]) {
            plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
          }
          if (mean_shift %in% remove_x_axis_models) {
            height <- 2.72
            plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          } else {
            height <- 3
          }
          if (!((h == corrupt_sizes[length(corrupt_sizes)]) & 
                (mean_shift %in% keep_legend_mean_shifts) &
                (x_model %in% keep_legend_x_models) & 
                (y_model %in% keep_legend_y_models))) {
            plt <- plt + ggplot2::guides(color = "none", alpha = "none")
          }
          plt_ls[[as.character(h)]] <- plt
        }
        plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
        ggplot2::ggsave(
          file.path(figures_dir, "robust_sims",
                    sprintf("robust_sims_%s_%s_%sMS_%s_errbars.pdf", y_model, x_model, mean_shift, metric)),
          plot = plt, units = "in", width = 14, height = height
        )
      } else {
        plt <- results %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_row_n",
            facet_str = "corrupt_size_name",
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            custom_theme = custom_theme
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name,
            color = "Method", alpha = "Method",
            title = sprintf("%s", sim_title)
          )
        vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                             fig_height = fig_height, fig_width = fig_width * 1.3)
        chunk_idx <- chunk_idx + 1
      }
    }
  }
}

```

```{r eval = TRUE, results = "asis"}
y_models <- c("lss_3m_2r")
x_models <- c("enhancer")

remove_x_axis_models <- c(10)
keep_legend_x_models <- x_models
keep_legend_y_models <- y_models
keep_legend_mean_shifts <- c(10)

if (params$for_paper) {
  for (y_model in y_models) {
    for (x_model in x_models) {
      sim_title <- dplyr::case_when(
        x_model == "ccle_rnaseq" ~ "CCLE",
        x_model == "splicing" ~ "Splicing",
        x_model == "enhancer" ~ "Enhancer",
        x_model == "juvenile" ~ "Juvenile"
      )
      for (mean_shift in mean_shifts) {
        plt_ls <- list()
        sim_name <- sprintf("%s_%s_%sMS_robust_dgp", x_model, y_model, mean_shift)
        fname <- file.path(results_dir, paste0("mdi_plus.robust_sims.", sim_name),
                           paste0("varying_", vary_param_name), 
                           paste0("seed", seed), "results")
        if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
          results <- readRDS(sprintf("%s.rds", fname))
          if (length(setdiff(show_methods, unique(results$method))) > 0) {
            results <- data.table::fread(sprintf("%s.csv", fname)) %>%
              reformat_results()
            saveRDS(results, sprintf("%s.rds", fname))
          }
        } else {
          results <- data.table::fread(sprintf("%s.csv", fname)) %>%
            reformat_results()
          saveRDS(results, sprintf("%s.rds", fname))
        }
        metric_name <- dplyr::case_when(
          metric == "rocauc" ~ "AUROC",
          metric == "prauc" ~ "PRAUC",
          TRUE ~ metric
        )
        tmp <- results %>%
          dplyr::filter(corrupt_size_name %in% corrupt_sizes,
                        method %in% show_methods) %>%
          dplyr::group_by(sample_row_n, corrupt_size_name, method) %>%
          dplyr::summarise(mean = mean(.data[[metric]]))
        min_y <- min(tmp$mean)
        max_y <- max(tmp$mean)
        for (h in corrupt_sizes) {
          plt <- results %>%
            dplyr::filter(corrupt_size_name == !!h) %>%
            plot_metrics(
              metric = metric,
              x_str = "sample_row_n",
              facet_str = NULL,
              point_size = point_size,
              line_size = line_size,
              errbar_width = errbar_width,
              manual_color_palette = manual_color_palette,
              show_methods = show_methods,
              method_labels = method_labels,
              alpha_values = alpha_values,
              legend_position = legend_position,
              custom_theme = custom_theme,
              inside_legend = TRUE
            ) +
            ggplot2::labs(
              x = "Sample Size", y = metric_name, 
              color = "Method", alpha = "Method"
            )
          if (h != corrupt_sizes[1]) {
            plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
          }
          if (mean_shift %in% remove_x_axis_models) {
            height <- 2.72
            plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          } else {
            height <- 3
          }
          if (!((h == corrupt_sizes[length(corrupt_sizes)]) & 
                (mean_shift %in% keep_legend_mean_shifts) &
                (x_model %in% keep_legend_x_models) & 
                (y_model %in% keep_legend_y_models))) {
            plt <- plt + ggplot2::guides(color = "none", alpha = "none")
          }
          plt_ls[[as.character(h)]] <- plt
        }
        plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
        ggplot2::ggsave(
          file.path(figures_dir, "robust_sims",
                    sprintf("main_robust_sims_%s_%s_%sMS_%s_errbars.pdf", y_model, x_model, mean_shift, metric)),
          plot = plt, units = "in", width = 14, height = height
        )
      }
    }
  }
}

```


# Correlation Bias Simulations {.tabset .tabset-vmodern}

## Main Figures {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
manual_color_palette <- c("black", "#218a1e", "orange", "#71beb7", "#cc3399")
show_methods <- c("GMDI", "MDI-oob", "MDA", "MDI",  "TreeSHAP")
method_labels <- c("MDI+ (ridge)", "MDI-oob", "MDA", "MDI", "TreeSHAP")

custom_color_palette <- c("#38761d", "#9dc89b", "#991500")
keep_heritabilities <- c(0.1, 0.4)
sim_name <- "mdi_plus.mdi_bias_sims.correlation_sims.normal_block_cor_partial_linear_lss_dgp"
fname <- file.path(
  results_dir, 
  sim_name,
  "varying_heritability_rho",
  sprintf("seed%s", seed),
  "results.csv"
)
results <- data.table::fread(fname) %>%
  dplyr::filter(heritability_name %in% keep_heritabilities)

n <- 250
p <- 100
sig_ids <- 0:5
cnsig_ids <- 6:49
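# sig_ids/cnsig_ids index the signal (Sig) and correlated non-signal (C-NSig)
# features; nreps below is used for the standard-error ribbons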
nreps <- length(unique(results$rep))

#### Examine average rank across varying correlation levels ####
plt_base <- plot_perturbation_stability(
  results,
  facet_rows = "heritability_name",
  facet_cols = "rho_name",
  param_name = NULL,
  sig_ids = sig_ids,
  cnsig_ids = cnsig_ids,
  plot_types = "errbar",
  manual_color_palette = manual_color_palette,
  show_methods = show_methods,
  method_labels = method_labels
)

for (idx in 1:length(keep_heritabilities)) {
  h <- keep_heritabilities[idx]
  cat(sprintf("\n\n### PVE = %s\n\n", h))
  plt_df <- plt_base$agg[[idx]]$data
  plt_ls <- list()
  for (method in method_labels) {
    plt_ls[[method]] <- plt_df %>%
      dplyr::filter(fi == method) %>%
      ggplot2::ggplot() +
      ggplot2::aes(x = rho_name, y = .mean, color = group) +
      ggplot2::geom_line() +
      ggplot2::geom_ribbon(
        ggplot2::aes(x = rho_name, 
                     ymin = .mean - (.sd / sqrt(nreps)), 
                     ymax = .mean + (.sd / sqrt(nreps)), 
                     fill = group), 
        inherit.aes = FALSE, alpha = 0.2
      ) +
      ggplot2::labs(
        x = expression("Correlation ("*rho*")"), 
        y = "Average Rank", 
        color = "Feature\nGroup", 
        title = method
      ) +
      ggplot2::coord_cartesian(
        ylim = c(min(plt_df$.mean) - 3, max(plt_df$.mean) + 3)
      ) +
      ggplot2::scale_color_manual(
        values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
      ) +
      ggplot2::scale_fill_manual(
        values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
      ) +
      ggplot2::guides(fill = "none", linetype = "none") +
      vthemes::theme_vmodern(
        size_preset = "medium", bg_color = "white", grid_color = "white",
        axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
        axis.text = ggplot2::element_text(size = 14),
        axis.title = ggplot2::element_text(size = 18, face = "plain"),
        legend.text = ggplot2::element_text(size = 14),
        legend.title = ggplot2::element_text(size = 18),
        plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
      )
    if (method != method_labels[1]) {
      plt_ls[[method]] <- plt_ls[[method]] +
        ggplot2::theme(axis.title.y = ggplot2::element_blank(),
                       axis.ticks.y = ggplot2::element_blank(),
                       axis.text.y = ggplot2::element_blank())
    }
  }
  
  plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1)
  if (params$for_paper) {
    ggplot2::ggsave(plt_wide, filename = file.path(figures_dir, sprintf("correlation_sim_pve%s_wide.pdf", h)),
                    # height = fig_height * .55, width = fig_width * .25 * length(show_methods))
                    height = fig_height * .55, width = fig_width * .3 * length(show_methods))
  } else {
    vthemes::subchunkify(
      plt_wide, i = chunk_idx, fig_height = fig_height * .8, fig_width = fig_width * 1.3
    )
    chunk_idx <- chunk_idx + 1
  }
}


#### Examine number of splits across varying correlation levels ####
cat("\n\n### Number of RF Splits \n\n")
results <- data.table::fread(fname) %>%
  dplyr::filter(heritability_name %in% keep_heritabilities,
                fi == "MDI_with_splits")

total_splits <- results %>%
  dplyr::group_by(rho_name, heritability, rep) %>%
  dplyr::summarise(n_splits = sum(av_splits))

plt_df <- results %>%
  dplyr::left_join(
    total_splits, by = c("rho_name", "heritability", "rep")
  ) %>%
  dplyr::mutate(
    av_splits = av_splits / n_splits * 100,
    group = dplyr::case_when(
      var %in% sig_ids ~ "Sig",
      var %in% cnsig_ids ~ "C-NSig",
      TRUE ~ "NSig"
    )
  ) %>%
  dplyr::mutate(
    group = factor(group, levels = c("Sig", "C-NSig", "NSig"))
  ) %>%
  dplyr::group_by(rho_name, heritability, group) %>%
  dplyr::summarise(
    .mean = mean(av_splits),
    .sd = sd(av_splits)
  )

min_y <- min(plt_df$.mean)
max_y <- max(plt_df$.mean)

plt_ls <- list()
for (idx in 1:length(keep_heritabilities)) {
  h <- keep_heritabilities[idx]
  plt <- plt_df %>%
    dplyr::filter(heritability == !!h) %>%
    ggplot2::ggplot() +
    ggplot2::aes(x = rho_name, y = .mean, color = group) +
    ggplot2::geom_line(size = 1) +
    # ggplot2::geom_ribbon(
    #   ggplot2::aes(x = rho_name, 
    #                ymin = .mean - (.sd / sqrt(nreps)), 
    #                ymax = .mean + (.sd / sqrt(nreps)), 
    #                fill = group), 
    #   inherit.aes = FALSE, alpha = 0.2
    # ) +
    ggplot2::labs(
      x = expression("Correlation ("*rho*")"), 
      y = "Percentage of Splits in RF (%)",
      color = "Feature\nGroup", 
      title = sprintf("PVE = %s", h)
    ) +
    ggplot2::coord_cartesian(ylim = c(min_y, max_y)) +
    ggplot2::scale_color_manual(
      values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
    ) +
    ggplot2::scale_fill_manual(
      values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
    ) +
    ggplot2::guides(fill = "none", linetype = "none") +
    vthemes::theme_vmodern(
      size_preset = "medium", bg_color = "white", grid_color = "white",
      axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
      axis.text = ggplot2::element_text(size = 14),
      axis.title = ggplot2::element_text(size = 18, face = "plain"),
      legend.text = ggplot2::element_text(size = 14),
      legend.title = ggplot2::element_text(size = 18),
      plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
    )
  if (idx != 1) {
    plt <- plt +
      ggplot2::theme(axis.title.y = ggplot2::element_blank())
  }
  plt_ls[[idx]] <- plt
}

plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1)
if (params$for_paper) {
  ggplot2::ggsave(plt_wide & ggplot2::theme(plot.title = ggplot2::element_blank()), 
                  filename = file.path(figures_dir, "correlation_sim_num_splits.pdf"),
                  width = fig_width * 1, height = fig_height * .65)
} else {
  vthemes::subchunkify(
    plt_wide, i = chunk_idx, fig_height = fig_height, fig_width = fig_width * 1.3
  )
  chunk_idx <- chunk_idx + 1
}

```

## Appendix Figures {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
manual_color_palette <- c("black", "gray")
show_methods <- c("GMDI", "GMDI_inbag")
method_labels <- c("MDI+ (LOO)", "MDI+ (in-bag)")

custom_color_palette <- c("#38761d", "#9dc89b", "#991500")
keep_heritabilities <- c(0.1, 0.4)
sim_name <- "mdi_plus.mdi_bias_sims.correlation_sims.normal_block_cor_partial_linear_lss_dgp"
fname <- file.path(
  results_dir, 
  sim_name,
  "varying_heritability_rho",
  sprintf("seed%s", seed),
  "results.csv"
)
results <- data.table::fread(fname) %>%
  dplyr::filter(heritability_name %in% keep_heritabilities)

n <- 250
p <- 100
sig_ids <- 0:5
cnsig_ids <- 6:49
nreps <- length(unique(results$rep))

#### Examine average rank across varying correlation levels ####
plt_base <- plot_perturbation_stability(
  results,
  facet_rows = "heritability_name",
  facet_cols = "rho_name",
  param_name = NULL,
  sig_ids = sig_ids,
  cnsig_ids = cnsig_ids,
  plot_types = "errbar",
  manual_color_palette = manual_color_palette,
  show_methods = show_methods,
  method_labels = method_labels
)

for (idx in 1:length(keep_heritabilities)) {
  h <- keep_heritabilities[idx]
  cat(sprintf("\n\n### PVE = %s\n\n", h))
  plt_df <- plt_base$agg[[idx]]$data
  plt_ls <- list()
  for (method in method_labels) {
    plt_ls[[method]] <- plt_df %>%
      dplyr::filter(fi == method) %>%
      ggplot2::ggplot() +
      ggplot2::aes(x = rho_name, y = .mean, color = group) +
      ggplot2::geom_line() +
      ggplot2::geom_ribbon(
        ggplot2::aes(x = rho_name, 
                     ymin = .mean - (.sd / sqrt(nreps)), 
                     ymax = .mean + (.sd / sqrt(nreps)), 
                     fill = group), 
        inherit.aes = FALSE, alpha = 0.2
      ) +
      ggplot2::labs(
        x = expression("Correlation ("*rho*")"), 
        y = "Average Rank", 
        color = "Feature\nGroup", 
        title = method
      ) +
      ggplot2::coord_cartesian(
        ylim = c(min(plt_df$.mean) - 3, max(plt_df$.mean) + 3)
      ) +
      ggplot2::scale_color_manual(
        values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
      ) +
      ggplot2::scale_fill_manual(
        values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
      ) +
      ggplot2::guides(fill = "none", linetype = "none") +
      vthemes::theme_vmodern(
        size_preset = "medium", bg_color = "white", grid_color = "white",
        axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
        axis.text = ggplot2::element_text(size = 14),
        axis.title = ggplot2::element_text(size = 18, face = "plain"),
        legend.text = ggplot2::element_text(size = 14),
        legend.title = ggplot2::element_text(size = 18),
        plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
      )
    if (method != method_labels[1]) {
      plt_ls[[method]] <- plt_ls[[method]] +
        ggplot2::theme(axis.title.y = ggplot2::element_blank(),
                       axis.ticks.y = ggplot2::element_blank(),
                       axis.text.y = ggplot2::element_blank())
    }
  }
  
  plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1)
  if (params$for_paper) {
    ggplot2::ggsave(plt_wide, 
                    filename = file.path(figures_dir, sprintf("correlation_sim_pve%s_wide_appendix.pdf", h)),
                    # height = fig_height * .55, width = fig_width * .25 * length(show_methods))
                    height = fig_height * .55, width = fig_width * .37 * length(show_methods))
  } else {
    vthemes::subchunkify(
      plt_wide, i = chunk_idx, fig_height = fig_height, fig_width = fig_width * 1.3
    )
    chunk_idx <- chunk_idx + 1
  }
}
```


# Entropy Bias Simulations {.tabset .tabset-vmodern}

## Main Figures {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
#### Entropy Regression Results ####
manual_color_palette <- rev(c("black", "orange", "#71beb7", "#218a1e", "#cc3399"))
show_methods <- rev(c("GMDI", "MDA", "MDI", "MDI-oob", "TreeSHAP"))
method_labels <- rev(c("MDI+ (ridge)", "MDA", "MDI", "MDI-oob", "TreeSHAP"))
alpha_values <- rev(c(1, rep(0.7, length(method_labels) - 1)))
y_limits <- c(1, 4.5)

metric_name <- "AUROC"
sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.linear_dgp"
fname <- file.path(
  results_dir, 
  sim_name,
  "varying_heritability_n",
  sprintf("seed%s", seed),
  "results.csv"
)

#### Entropy Regression Avg Rank ####
results <- data.table::fread(fname) %>%
  dplyr::filter(heritability_name == 0.1)
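# identity grouping: each feature forms its own group so that the signal feature
# X1 (var = 0) can be tracked individually in the average-rank plots below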

custom_group_fun <- function(var) var

plt_base <- plot_perturbation_stability(
  results,
  facet_rows = "heritability_name",
  facet_cols = "n_name", 
  group_fun = custom_group_fun, 
  param_name = NULL,
  plot_types = "errbar",
  manual_color_palette = manual_color_palette,
  show_methods = show_methods,
  method_labels = method_labels
)

plt_df <- plt_base$agg[[1]]$data
nreps <- length(unique(results$rep))

plt_regression <- plt_df %>%
  dplyr::filter(group == 0) %>%
  ggplot2::ggplot() +
  ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi) +
  ggplot2::geom_line(size = line_size) +
  ggplot2::labs(
    x = "Sample Size",
    y = expression("Average Rank of "*X[1]),
    color = "Method",
    title = "Regression"
  ) +
  ggplot2::guides(fill = "none", alpha = "none") +
  ggplot2::scale_color_manual(values = manual_color_palette,
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::scale_alpha_manual(values = alpha_values, 
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::coord_cartesian(ylim = y_limits) +
  vthemes::theme_vmodern(
    size_preset = "medium", bg_color = "white", grid_color = "white",
    axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
    axis.text = ggplot2::element_text(size = 14),
    axis.title = ggplot2::element_text(size = 18, face = "plain"),
    legend.text = ggplot2::element_text(size = 14),
    legend.title = ggplot2::element_text(size = 18),
    plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
  )

#### Entropy Regression # Splits ####
results <- data.table::fread(fname) %>%
  dplyr::filter(heritability_name == 0.1, 
                fi == "MDI_with_splits")

nreps <- length(unique(results$rep))
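# total RF splits per (sample size, replicate); used to convert per-feature split
# counts into percentages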

total_splits <- results %>%
  dplyr::group_by(n_name, heritability, rep) %>%
  dplyr::summarise(n_splits = sum(av_splits))

regression_num_splits_df <- results %>%
  dplyr::left_join(
    total_splits, by = c("n_name", "heritability", "rep")
  ) %>%
  dplyr::mutate(
    av_splits = av_splits / n_splits * 100
  ) %>%
  dplyr::group_by(n_name, heritability, var) %>%
  dplyr::summarise(
    .mean = mean(av_splits),
    .sd = sd(av_splits)
  )

#### Entropy Classification Results ####
manual_color_palette <- rev(c("black", "#9B5DFF", "orange", "#71beb7", "#218a1e", "#cc3399"))
show_methods <- rev(c("GMDI_ridge", "GMDI_logistic_logloss", "MDA", "MDI", "MDI-oob", "TreeSHAP"))
method_labels <- rev(c("MDI+ (ridge)", "MDI+ (logistic)", "MDA", "MDI", "MDI-oob", "TreeSHAP"))
alpha_values <- rev(c(1, 1, rep(0.7, length(method_labels) - 2)))

metric_name <- "AUROC"
sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.logistic_dgp"
fname <- file.path(
  results_dir, 
  sim_name,
  "varying_c_n",
  sprintf("seed%s", seed),
  "results.csv"
)

#### Entropy Classification Avg Rank ####
results <- data.table::fread(fname)

custom_group_fun <- function(var) var

plt_base <- plot_perturbation_stability(
  results,
  facet_rows = "c_name",
  facet_cols = "n_name", 
  group_fun = custom_group_fun, 
  param_name = NULL,
  plot_types = "errbar",
  manual_color_palette = manual_color_palette,
  show_methods = show_methods,
  method_labels = method_labels
)

plt_df <- plt_base$agg[[1]]$data
nreps <- length(unique(results$rep))

plt_classification <- plt_df %>%
  dplyr::filter(group == 0) %>%
  ggplot2::ggplot() +
  ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi) +
  ggplot2::geom_line(size = line_size) +
  ggplot2::labs(
    x = "Sample Size", 
    y = expression("Average Rank of "*X[1]), 
    color = "Method",
    title = "Classification"
  ) +
  ggplot2::guides(fill = "none", alpha = "none") +
  ggplot2::scale_color_manual(values = manual_color_palette,
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::scale_alpha_manual(values = alpha_values, 
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::coord_cartesian(
    ylim = y_limits
  ) +
  vthemes::theme_vmodern(
    size_preset = "medium", bg_color = "white", grid_color = "white",
    axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
    axis.text = ggplot2::element_text(size = 14),
    axis.title = ggplot2::element_text(size = 18, face = "plain"),
    legend.text = ggplot2::element_text(size = 14),
    legend.title = ggplot2::element_text(size = 18),
    plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
  )

#### Entropy Classification # Splits ####
results <- data.table::fread(fname) %>%
  dplyr::filter(fi == "MDI_with_splits")

total_splits <- results %>%
  dplyr::group_by(n_name, c, rep) %>%
  dplyr::summarise(n_splits = sum(av_splits))

classification_num_splits_df <- results %>%
  dplyr::left_join(
    total_splits, by = c("n_name", "c", "rep")
  ) %>%
  dplyr::mutate(
    av_splits = av_splits / n_splits * 100
  ) %>%
  dplyr::group_by(n_name, c, var) %>%
  dplyr::summarise(
    .mean = mean(av_splits),
    .sd = sd(av_splits)
  )

#### Show Avg Rank Plot ####
cat("\n\n### Average Rank \n\n")
plt <- plt_regression + plt_classification
if (params$for_paper) {
  ggplot2::ggsave(plt, filename = file.path(figures_dir, "entropy_sims.pdf"), 
                  width = fig_width * 1.2, height = fig_height * .65)
} else {
  vthemes::subchunkify(
    plt, i = chunk_idx, fig_height = fig_height * .8, fig_width = fig_width * 1.3
  )
  chunk_idx <- chunk_idx + 1
}

#### Show # Splits Plot ####
cat("\n\n### Number of RF Splits \n\n")
plt_df <- dplyr::bind_rows(
  Regression = regression_num_splits_df, 
  Classification = classification_num_splits_df, 
  .id = "type"
) %>%
  dplyr::mutate(
    var = dplyr::case_when(
      var == 0 ~ "Bernoulli (Signal)",
      var == 1 ~ "Normal (Non-Signal)",
      var == 2 ~ "4-Categories (Non-Signal)",
      var == 3 ~ "10-Categories (Non-Signal)",
      var == 4 ~ "20-Categories (Non-Signal)"
    ) %>%
      factor(levels = c("Normal (Non-Signal)",
                        "20-Categories (Non-Signal)",
                        "10-Categories (Non-Signal)",
                        "4-Categories (Non-Signal)",
                        "Bernoulli (Signal)"))
  )

sim_types <- unique(plt_df$type)
plt_ls <- list()
for (idx in 1:length(sim_types)) {
  sim_type <- sim_types[idx]
  plt <- plt_df %>%
    dplyr::filter(type == !!sim_type) %>%
    ggplot2::ggplot() +
    ggplot2::aes(x = n_name, y = .mean, color = as.factor(var)) +
    ggplot2::geom_line(size = 1) +
    ggplot2::labs(
      x = "Sample Size",
      y = "Percentage of Splits in RF (%)",
      color = "Feature", 
      title = sim_type
    ) +
    ggplot2::guides(fill = "none", linetype = "none") +
    vthemes::theme_vmodern(
      size_preset = "medium", bg_color = "white", grid_color = "white",
      axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
      axis.text = ggplot2::element_text(size = 14),
      axis.title = ggplot2::element_text(size = 18, face = "plain"),
      legend.text = ggplot2::element_text(size = 14),
      legend.title = ggplot2::element_text(size = 18),
      plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
    )
  if (idx != 1) {
    plt <- plt +
      ggplot2::theme(axis.title.y = ggplot2::element_blank())
  }
  plt_ls[[idx]] <- plt
}

plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1)
if (params$for_paper) {
  ggplot2::ggsave(plt_wide, filename = file.path(figures_dir, "entropy_sims_num_splits.pdf"),
                  # height = fig_height * .55, width = fig_width * .25 * length(show_methods))
                  width = fig_width * 1.1, height = fig_height * .65)
} else {
  vthemes::subchunkify(
    plt_wide, i = chunk_idx, fig_height = fig_height, fig_width = fig_width * 1.3
  )
  chunk_idx <- chunk_idx + 1
}

```

## Appendix Figures {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
#### Entropy Regression Results ####
manual_color_palette <- rev(c("black", "black", "#71beb7"))
show_methods <- rev(c("GMDI", "GMDI_inbag", "MDI"))
method_labels <- rev(c("MDI+ (ridge, LOO)", "MDI+ (ridge, in-bag)", "MDI"))
alpha_values <- rev(c(1, 1, rep(1, length(method_labels) - 2)))
linetype_values <- c("solid", "dashed", "solid")
y_limits <- c(1, 5)

metric_name <- "AUROC"
sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.linear_dgp"
fname <- file.path(
  results_dir, 
  sim_name,
  "varying_heritability_n",
  sprintf("seed%s", seed),
  "results.csv"
)

#### Entropy Regression Avg Rank ####
results <- data.table::fread(fname) %>%
  dplyr::filter(heritability_name == 0.1)

custom_group_fun <- function(var) var

plt_base <- plot_perturbation_stability(
  results,
  facet_rows = "heritability_name",
  facet_cols = "n_name", 
  group_fun = custom_group_fun, 
  param_name = NULL,
  plot_types = "errbar",
  manual_color_palette = manual_color_palette,
  show_methods = show_methods,
  method_labels = method_labels
)

plt_df <- plt_base$agg[[1]]$data
nreps <- length(unique(results$rep))

plt_regression <- plt_df %>%
  dplyr::filter(group == 0) %>%
  ggplot2::ggplot() +
  ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi, linetype = fi) +
  ggplot2::geom_line(size = line_size) +
  ggplot2::labs(
    x = "Sample Size", 
    y = expression("Average Rank of "*X[1]), 
    color = "Method",
    linetype = "Method",
    title = "Regression"
  ) +
  ggplot2::guides(fill = "none", alpha = "none") +
  ggplot2::scale_color_manual(values = manual_color_palette,
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::scale_alpha_manual(values = alpha_values, 
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::scale_linetype_manual(values = linetype_values, 
                                 labels = method_labels,
                                 guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::coord_cartesian(ylim = y_limits) +
  vthemes::theme_vmodern(
    size_preset = "medium", bg_color = "white", grid_color = "white",
    axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
    axis.text = ggplot2::element_text(size = 14),
    axis.title = ggplot2::element_text(size = 18, face = "plain"),
    legend.text = ggplot2::element_text(size = 14),
    legend.title = ggplot2::element_text(size = 18),
    plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5),
    legend.key.width = ggplot2::unit(1, "cm")
  )

#### Entropy Classification Results ####
manual_color_palette <- rev(c("black", "black", "#9B5DFF", "#9B5DFF", "#71beb7"))
show_methods <- rev(c("GMDI_ridge", "GMDI_ridge_inbag", "GMDI_logistic_logloss", "GMDI_logistic_logloss_inbag", "MDI"))
method_labels <- rev(c("MDI+ (ridge, LOO)", "MDI+ (ridge, in-bag)", "MDI+ (logistic, LOO)", "MDI+ (logistic, in-bag)", "MDI"))
alpha_values <- rev(c(1, 1, 1, 1, rep(1, length(method_labels) - 4)))
linetype_values <- c("solid", "dashed", "solid", "dashed", "solid")

metric_name <- "AUROC"
sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.logistic_dgp"
fname <- file.path(
  results_dir, 
  sim_name,
  "varying_c_n",
  sprintf("seed%s", seed),
  "results.csv"
)

#### Entropy Classification Avg Rank ####
results <- data.table::fread(fname)

custom_group_fun <- function(var) var

plt_base <- plot_perturbation_stability(
  results,
  facet_rows = "c_name",
  facet_cols = "n_name", 
  group_fun = custom_group_fun, 
  param_name = NULL,
  plot_types = "errbar",
  manual_color_palette = manual_color_palette,
  show_methods = show_methods,
  method_labels = method_labels
)

plt_df <- plt_base$agg[[1]]$data
nreps <- length(unique(results$rep))

plt_classification <- plt_df %>%
  dplyr::filter(group == 0) %>%
  ggplot2::ggplot() +
  ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi, linetype = fi) +
  ggplot2::geom_line(size = line_size) +
  ggplot2::labs(
    x = "Sample Size", 
    y = expression("Average Rank of "*X[1]), 
    color = "Method",
    linetype = "Method",
    title = "Classification"
  ) +
  ggplot2::guides(fill = "none", alpha = "none") +
  ggplot2::scale_color_manual(values = manual_color_palette,
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::scale_alpha_manual(values = alpha_values, 
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::scale_linetype_manual(values = linetype_values,
                                 labels = method_labels,
                                 guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::coord_cartesian(
    ylim = y_limits
  ) +
  vthemes::theme_vmodern(
    size_preset = "medium", bg_color = "white", grid_color = "white",
    axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
    axis.text = ggplot2::element_text(size = 14),
    axis.title = ggplot2::element_text(size = 18, face = "plain"),
    legend.text = ggplot2::element_text(size = 14),
    legend.title = ggplot2::element_text(size = 18),
    plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5),
    legend.key.width = ggplot2::unit(1, "cm")
  )

#### Show Avg Rank Plot ####
cat("\n\n### Average Rank \n\n")
plt <- plt_regression + plt_classification
if (params$for_paper) {
  ggplot2::ggsave(plt, filename = file.path(figures_dir, "entropy_sims_appendix.pdf"), 
                  width = fig_width * 1.35, height = fig_height * .65)
} else {
  vthemes::subchunkify(
    plt, i = chunk_idx, fig_height = fig_height * .7, fig_width = fig_width * 1.3
  )
  chunk_idx <- chunk_idx + 1
}

```


# CCLE Case Study {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
manual_color_palette <- c("black", "orange", "#71beb7", "#218a1e", "#cc3399")
show_methods <- c("MDI+", "MDA", "MDI", "MDI-oob", "TreeSHAP")
method_labels <- c("MDI+ (ridge)", "MDA", "MDI", "MDI-oob", "TreeSHAP")
alpha_values <- c(1, rep(0.4, length(method_labels) - 1))

fpath <- "mdi_plus.real_data_case_study.ccle_rnaseq_regression-"
fname <- file.path(
  results_dir, 
  fpath,
  sprintf("seed%s", seed),
  "results.csv"
)
pred_fname <- file.path(
  results_dir, 
  fpath,
  sprintf("seed%s", seed),
  "pred_results.csv"
)
X <- data.table::fread("/global/scratch/users/tiffanytang/feature_importance/data/X_ccle_rnaseq_cleaned_filtered5000.csv")
X_columns <- colnames(X)
results <- data.table::fread(fname)
pred_results <- data.table::fread(pred_fname) %>%
  tibble::as_tibble() %>%
  dplyr::mutate(method = fi) %>%
  # tidyr::unite(col = "method", fi, model, na.rm = TRUE, remove = FALSE) %>%
  tidyr::pivot_wider(names_from = metric, values_from = metric_value)

out <- plot_top_stability(
  results,
  group_id = "y_task",
  top_r = 10,
  show_max_features = 5,
  varnames = colnames(X),
  base_method = "MDI+ (ridge)",
  return_df = TRUE,
  manual_color_palette = rev(manual_color_palette),
  show_methods = rev(show_methods),
  method_labels = rev(method_labels)
)

ranking_df <- out$rankings
stability_df <- out$stability
plt_ls <- out$plot_ls

# get gene names instead of ENSG ids
library(EnsDb.Hsapiens.v79)
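# (the trailing Ensembl version suffix, e.g. ".5", is stripped from each ID before the lookup)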
varnames <- ensembldb::select(
  EnsDb.Hsapiens.v79, 
  keys = stringr::str_remove(colnames(X), "\\.[^\\.]+$"), 
  keytype = "GENEID", columns = c("GENEID","SYMBOL")
)

# rank genes by their average importance rank across splits (per drug and method)
top_genes <- ranking_df %>%
  dplyr::group_by(y_task, fi, var) %>%
  dplyr::summarise(mean_rank = mean(rank)) %>%
  dplyr::ungroup() %>%
  dplyr::arrange(y_task, mean_rank) %>%
  dplyr::group_by(y_task, fi) %>%
  dplyr::mutate(rank = 1:dplyr::n()) %>%
  dplyr::ungroup()

top_genes_table <- top_genes %>%
  dplyr::filter(rank <= 50) %>%
  dplyr::left_join(
    y = data.frame(var = 0:(ncol(X) - 1), 
                   GENEID = stringr::str_remove(colnames(X), "\\.[^\\.]+$")),
    by = "var"
  ) %>%
  dplyr::left_join(
    y = varnames, by = "GENEID"
  ) %>%
  dplyr::mutate(
    value = sprintf("%s (%s)", SYMBOL, round(mean_rank, 2))
  ) %>%
  tidyr::pivot_wider(
    id_cols = c(y_task, rank), 
    names_from = fi, 
    values_from = value
  ) %>%
  dplyr::rename(
    "Drug" = "y_task",
    "Rank" = "rank"
  )

# write.csv(
#   top_genes_table, file.path(tables_dir, "ccle_rnaseq_top_genes_table.csv"),
#   row.names = FALSE, quote = FALSE
# )

# get # genes with perfect stability
stable_genes_kable <- stability_df %>%
  dplyr::group_by(fi, y_task) %>%
  dplyr::summarise(stability_1 = sum(stability_score == 1)) %>%
  tidyr::pivot_wider(id_cols = fi, names_from = y_task, values_from = stability_1) %>%
  dplyr::ungroup() %>%
  as.data.frame()

pred_plt_ls <- list()
for (plt_id in names(plt_ls)) {
  cat(sprintf("\n\n## %s \n\n", plt_id))
  plt <- plt_ls[[plt_id]]
  if (plt_id == "Summary") {
    plt <- plt + ggplot2::labs(y = "Drug")
    if (params$for_paper) {
      ggplot2::ggsave(plt, filename = file.path(figures_dir, "ccle_rnaseq_stability_top10.pdf"), 
                      width = fig_width * 1.5, height = fig_height * 1.1)
    } else {
      vthemes::subchunkify(
        plt, i = sprintf("%s-%s", fpath, plt_id), 
        fig_height = fig_height, fig_width = fig_width * 1.3
      )
    }
    
    cat("\n\n## Table of Top Genes \n\n")
    if (params$for_paper) {
      top_genes_kable <- top_genes_table %>%
        dplyr::filter(Rank <= 5) %>%
        dplyr::select(-Drug) %>%
        dplyr::rename(" " = "Rank") %>%
        vthemes::pretty_kable(format = "latex", longtable = TRUE) %>%
        kableExtra::kable_styling(latex_options = "repeat_header") %>%
        # kableExtra::collapse_rows(columns = 1, valign = "middle")
        kableExtra::pack_rows(
          index = rep(5, 24) %>% setNames(unique(top_genes_table$Drug))
        )
    } else {
      top_genes_kable <- top_genes_table %>%
        dplyr::filter(Rank <= 5) %>%
        dplyr::select(-Drug) %>%
        dplyr::rename(" " = "Rank") %>%
        vthemes::pretty_kable() %>%
        kableExtra::kable_styling(latex_options = "repeat_header") %>%
        # kableExtra::collapse_rows(columns = 1, valign = "middle")
        kableExtra::pack_rows(
          index = rep(5, 24) %>% setNames(unique(top_genes_table$Drug))
        )
      vthemes::subchunkify(
        top_genes_kable, i = sprintf("%s-table", fpath)
      )
    }
  } else {
    pred_plt_ls[[plt_id]] <- pred_results %>%
      dplyr::filter(k <= 10, y_task == !!plt_id) %>%
      plot_metrics(
        metric = "r2",
        x_str = "k",
        facet_str = NULL,
        point_size = point_size,
        line_size = line_size,
        errbar_width = errbar_width,
        manual_color_palette = manual_color_palette,
        show_methods = show_methods,
        method_labels = method_labels,
        alpha_values = alpha_values,
        inside_legend = FALSE
      ) +
      ggplot2::labs(x = "Number of Top Features (k)", y = "Test R-squared",
                    color = "Method", alpha = "Method", title = plt_id) +
      vthemes::theme_vmodern()
    
    if (!params$for_paper) {
      vthemes::subchunkify(
        plot_fun(plt), i = sprintf("%s-%s", fpath, plt_id), 
        fig_height = fig_height, fig_width = fig_width * 1.3
      )
      vthemes::subchunkify(
        plot_fun(pred_plt_ls[[plt_id]]), i = sprintf("%s-%s-pred", fpath, plt_id), 
        fig_height = fig_height, fig_width = fig_width * 1.3
      )
    }
  }
}

if (params$for_paper) {
  all_plt_ls <- list()
  for (idx in 1:length(pred_plt_ls)) {
    drug <- names(pred_plt_ls)[idx]
    plt <- pred_plt_ls[[drug]]
    if ((idx - 1) %% 4 != 0) {
      plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
    }
    if ((idx - 1) %/% 4 != 5) {
      plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
    }
    all_plt_ls[[drug]] <- plt
  }
  plt <- patchwork::wrap_plots(all_plt_ls, guides = "collect", ncol = 4, nrow = 6) &
    ggplot2::theme(
      plot.title = ggplot2::element_text(hjust = 0.5),
      panel.background = ggplot2::element_rect(fill = "white"),
      panel.grid.major = ggplot2::element_line(color = "white")
    )
  ggplot2::ggsave(plt, 
                  filename = file.path(figures_dir, "ccle_rnaseq_predictions.pdf"),
                  width = fig_width * 1, height = fig_height * 2)
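  # rank the methods within each drug by mean test R^2 when using the top 10 features,
  # then tabulate how often each method attains each rank across drugs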
  
  pred10_ranks <- pred_results %>% 
    dplyr::filter(k == 10) %>% 
    dplyr::group_by(fi, y_task) %>% 
    dplyr::summarise(r2 = mean(r2)) %>% 
    tidyr::pivot_wider(names_from = y_task, values_from = r2) %>% 
    dplyr::ungroup() %>% 
    dplyr::mutate(dplyr::across(`17-AAG`:`ZD-6474`, ~rank(-.x)))
  pred10_ranks_table <- apply(pred10_ranks, 1, FUN = function(x) table(x[-1])) %>%
    tibble::as_tibble() %>%
    tibble::rownames_to_column("Rank") %>%
    setNames(c("Rank", pred10_ranks$fi)) %>%
    dplyr::select(Rank, `MDI+`, TreeSHAP, MDI, `MDI-oob`, MDA) %>%
    vthemes::pretty_kable(format = "latex")
  top_genes_kable <- top_genes_table %>%
    dplyr::filter(Rank <= 5) %>%
    dplyr::select(-Drug) %>%
    dplyr::rename(" " = "Rank") %>%
    vthemes::pretty_kable() %>%
    kableExtra::kable_styling(latex_options = "repeat_header") %>%
    # kableExtra::collapse_rows(columns = 1, valign = "middle")
    kableExtra::pack_rows(
      index = rep(5, 24) %>% setNames(unique(top_genes_table$Drug))
    )
  vthemes::subchunkify(
    top_genes_kable, i = sprintf("%s-table", fpath)
  )
}
```

```{r eval = TRUE, results = "asis"}
if (params$for_paper) {
  manual_color_palette_full <- c("black", "orange", "#71beb7", "#218a1e", "#cc3399")
  show_methods_full <- c("MDI+", "MDA", "MDI", "MDI-oob", "TreeSHAP")
  method_labels_full <- c("MDI+ (ridge)", "MDA", "MDI", "MDI-oob", "TreeSHAP")
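  # method subsets to plot: all methods, all but MDI-oob, and a reduced set
  # (MDI+, MDI, TreeSHAP) for a zoomed-in comparison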
  keep_methods_list <- list(
    all = show_methods_full,
    no_mdioob = c("MDI+", "MDA", "MDI", "TreeSHAP"),
    zoom = c("MDI+", "MDI", "TreeSHAP")
  )
  
  fpath <- "mdi_plus.real_data_case_study.ccle_rnaseq_regression-"
  fname <- file.path(
    results_dir, 
    fpath,
    sprintf("seed%s", seed),
    "results.csv"
  )
  X <- data.table::fread("/global/scratch/users/tiffanytang/feature_importance/data/X_ccle_rnaseq_cleaned_filtered5000.csv")
  results <- data.table::fread(fname)
  
  for (id in names(keep_methods_list)) {
    keep_methods <- keep_methods_list[[id]]
    manual_color_palette <- manual_color_palette_full[show_methods_full %in% keep_methods]
    show_methods <- show_methods_full[show_methods_full %in% keep_methods]
    method_labels <- method_labels_full[show_methods_full %in% keep_methods]
    
    plt_ls <- plot_top_stability(
      results,
      group_id = "y_task",
      top_r = 10,
      show_max_features = 5,
      varnames = colnames(X),
      base_method = "MDI+",
      return_df = FALSE,
      manual_color_palette = manual_color_palette,
      show_methods = show_methods,
      method_labels = method_labels
    )
    all_plt_ls <- list()
    for (idx in 1:length(unique(results$y_task))) {
      drug <- sort(unique(results$y_task))[idx]
      plt <- plt_ls[[drug]]
      if ((idx - 1) %% 4 != 0) {
        plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
      }
      if ((idx - 1) %/% 4 != 5) {
        plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
      }
      all_plt_ls[[drug]] <- plt
    }
    plt <- patchwork::wrap_plots(all_plt_ls, guides = "collect", ncol = 4, nrow = 6) &
      ggplot2::theme(
        plot.title = ggplot2::element_text(hjust = 0.5),
        panel.background = ggplot2::element_rect(fill = "white"),
        panel.grid.major = ggplot2::element_line(color = "grey95", size = ggplot2::rel(0.25))
      )
    ggplot2::ggsave(plt, 
                    filename = file.path(figures_dir, sprintf("ccle_rnaseq_stability_select_features_%s.pdf", id)),
                    width = fig_width * 1, height = fig_height * 2)
  }
}
```


# TCGA BRCA Case Study {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
manual_color_palette <- c("black", "#9B5DFF", "orange", "#71beb7", "#cc3399")
show_methods <- c("MDI+_ridge", "MDI+_logistic_logloss", "MDA", "MDI", "TreeSHAP")
method_labels <- c("MDI+ (ridge)", "MDI+ (logistic)", "MDA", "MDI", "TreeSHAP")
alpha_values <- c(1, 1, rep(0.4, length(method_labels) - 2))

fpath <- "mdi_plus.real_data_case_study.tcga_brca_classification-"
fname <- file.path(
  results_dir, 
  fpath,
  sprintf("seed%s", seed),
  "results.csv"
)
X <- data.table::fread("/global/scratch/users/tiffanytang/feature_importance/data/X_tcga_cleaned.csv")
results <- data.table::fread(fname)
out <- plot_top_stability(
  results,
  group_id = NULL,
  top_r = 10,
  show_max_features = 20,
  varnames = colnames(X),
  base_method = "MDI+ (ridge)", 
  # descending_methods = "MDI+ (logistic)",
  return_df = TRUE,
  manual_color_palette = manual_color_palette,
  show_methods = show_methods,
  method_labels = method_labels
)

ranking_df <- out$rankings
stability_df <- out$stability
plt_ls <- out$plot_ls

cat("\n\n## Non-zero Stability Scores Per Method\n\n")
vthemes::subchunkify(
  plt_ls[[1]], i = sprintf("%s-%s", fpath, 1), 
  fig_height = fig_height * round(length(unique(results$fi)) / 2), fig_width = fig_width
)

cat("\n\n## Stability of Top Features\n\n")
vthemes::subchunkify(
  plot_fun(plt_ls[[2]]), i = sprintf("%s-%s", fpath, 2), 
  fig_height = fig_height, fig_width = fig_width
)

# get top 25 genes
top_genes <- ranking_df %>%
  dplyr::group_by(fi, var) %>%
  dplyr::summarise(mean_rank = mean(rank)) %>%
  dplyr::ungroup() %>%
  dplyr::arrange(mean_rank) %>%
  dplyr::group_by(fi) %>%
  dplyr::mutate(rank = 1:dplyr::n()) %>%
  dplyr::ungroup()

top_genes_table <- top_genes %>%
  dplyr::filter(rank <= 25) %>%
  dplyr::left_join(
    y = data.frame(var = 0:(ncol(X) - 1), 
                   gene = colnames(X)),
    by = "var"
  ) %>%
  dplyr::mutate(
    value = sprintf("%s (%s)", gene, round(mean_rank, 2))
  ) %>%
  tidyr::pivot_wider(
    id_cols = rank, 
    names_from = fi, 
    values_from = value
  ) %>%
  dplyr::rename(
    "Rank" = "rank"
  )

# write.csv(
#   top_genes_table, file.path(tables_dir, "tcga_top_genes_table.csv"),
#   row.names = FALSE, quote = FALSE
# )

stable_genes_kable <- stability_df %>%
  dplyr::group_by(fi) %>%
  dplyr::summarise(stability_1 = sum(stability_score == 1)) %>%
  tibble::as_tibble()

cat("\n\n## Table of Top Genes \n\n")
if (params$for_paper) {
  top_genes_kable <- top_genes_table %>%
    vthemes::pretty_kable(format = "latex")#, longtable = TRUE) %>%
    # kableExtra::kable_styling(latex_options = "repeat_header")
} else {
  top_genes_kable <- top_genes_table %>%
    vthemes::pretty_kable() %>%
    kableExtra::kable_styling(latex_options = "repeat_header")
  vthemes::subchunkify(
    top_genes_kable, i = sprintf("%s-table", fpath)
  )
}

cat("\n\n## Prediction using Top Features\n\n")
pred_fname <- file.path(
  results_dir, 
  fpath,
  sprintf("seed%s", seed),
  "pred_results.csv"
)
pred_results <- data.table::fread(pred_fname) %>%
  tibble::as_tibble() %>%
  dplyr::mutate(method = fi) %>%
  tidyr::pivot_wider(names_from = metric, values_from = metric_value)
plt <- pred_results %>%
  dplyr::filter(k <= 25) %>%
  plot_metrics(
    metric = c("rocauc", "accuracy"),
    x_str = "k",
    facet_str = NULL,
    point_size = point_size,
    line_size = line_size,
    errbar_width = errbar_width,
    manual_color_palette = manual_color_palette,
    show_methods = show_methods,
    method_labels = method_labels,
    alpha_values = alpha_values,
    custom_theme = custom_theme,
    inside_legend = FALSE
  ) +
  ggplot2::labs(x = "Number of Top Features (k)", y = "Mean Prediction Performance", 
                color = "Method", alpha = "Method")
if (params$for_paper) {
  ggplot2::ggsave(plt, filename = file.path(figures_dir, "tcga_brca_prediction.pdf"),
                  width = fig_width * 1.3, height = fig_height * .85)
} else {
  vthemes::subchunkify(
    plot_fun(plt), i = sprintf("%s-pred", fpath), 
    fig_height = fig_height * .8, fig_width = fig_width * 1.5
  )
}
```

```{r eval = TRUE, results = "asis"}
if (params$for_paper) {
  manual_color_palette <- c("black", "#9B5DFF", "orange", "#71beb7", "#218a1e", "#cc3399")
  show_methods <- c("MDI+_ridge", "MDI+_logistic_logloss", "MDA", "MDI", "MDI-oob", "TreeSHAP")
  method_labels <- c("MDI+ (ridge)", "MDI+ (logistic)", "MDA", "MDI", "MDI-oob", "TreeSHAP")
  
  # ccle rnaseq
  fpath <- "mdi_plus.real_data_case_study.ccle_rnaseq_regression-"
  fname <- file.path(
    results_dir, 
    fpath,
    sprintf("seed%s", seed),
    "results.csv"
  )
  X_ccle <- data.table::fread("/global/scratch/users/tiffanytang/feature_importance/data/X_ccle_rnaseq_cleaned_filtered5000.csv")
  results_ccle <- data.table::fread(fname) %>%
    dplyr::mutate(
      fi = ifelse(fi == "MDI+", "MDI+_ridge", fi)
    )
  
  # tcga
  fpath <- "mdi_plus.real_data_case_study.tcga_brca_classification-"
  fname <- file.path(
    results_dir, 
    fpath,
    sprintf("seed%s", seed),
    "results.csv"
  )
  X_tcga <- data.table::fread("/global/scratch/users/tiffanytang/feature_importance/data/X_tcga_cleaned.csv")
  results_tcga <- data.table::fread(fname)
  
  # join results
  keep_cols <- c("rep", "y_task", "model", "fi", "index", "var", "importance")
  results <- dplyr::bind_rows(
    results_ccle %>% dplyr::select(tidyselect::all_of(keep_cols)),
    results_tcga %>% dplyr::mutate(y_task = "TCGA-BRCA") %>% dplyr::select(tidyselect::all_of(keep_cols))
  ) %>%
    tibble::as_tibble()
  plt_ls <- plot_top_stability(
    results,
    group_id = "y_task",
    top_r = 10,
    show_max_features = 5,
    base_method = "MDI+ (ridge)",
    return_df = FALSE,
    manual_color_palette = rev(manual_color_palette),
    show_methods = rev(show_methods),
    method_labels = rev(method_labels)
  )
  
  plt <- plt_ls$Summary +
    ggplot2::labs(
      y = "Drug",
      x = "Number of Distinct Features in Top 10 Across 32 Training-Test Splits"
    ) +
    ggplot2::scale_x_continuous(
      n.breaks = 6
    )
  ggplot2::ggsave(plt, filename = file.path(figures_dir, "casestudy_stability_top10.pdf"), 
                        width = fig_width * 1.5, height = fig_height * 1.2)
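  # smaller method set (without MDI-oob) for the per-drug and TCGA-BRCA
  # feature-stability panels below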
  
  keep_methods <- c("MDI+_ridge", "MDI+_logistic_logloss", "MDA", "MDI", "TreeSHAP")
  manual_color_palette_small <- manual_color_palette[show_methods %in% keep_methods]
  show_methods_small <- show_methods[show_methods %in% keep_methods]
  method_labels_small <- method_labels[show_methods %in% keep_methods]
  
  plt_tcga <- plot_top_stability(
    results_tcga,
    top_r = 10,
    show_max_features = 5,
    base_method = "MDI+ (ridge)",
    return_df = FALSE,
    manual_color_palette = manual_color_palette_small,
    show_methods = show_methods_small,
    method_labels = method_labels_small
  )$`Stability of Top Features` +
    ggplot2::labs(title = "TCGA-BRCA") + 
    ggplot2::theme(axis.title.y = ggplot2::element_blank())
  
  plt_ls <- plot_top_stability(
    results_ccle,
    group_id = "y_task",
    top_r = 10,
    show_max_features = 5,
    # varnames = colnames(X),
    base_method = "MDI+ (ridge)",
    return_df = FALSE,
    manual_color_palette = manual_color_palette_small[show_methods_small != "MDI+_logistic_logloss"],
    show_methods = show_methods_small[show_methods_small != "MDI+_logistic_logloss"],
    method_labels = method_labels_small[show_methods_small != "MDI+_logistic_logloss"]
  )
  
  # keep_drugs <- c("Panobinostat", "Lapatinib", "L-685458")
  keep_drugs <- c("PD-0325901", "Panobinostat", "L-685458")
  small_plt_ls <- list()
  for (drug in keep_drugs) {
    plt <- plt_ls[[drug]]
    if (drug != keep_drugs[1]) {
      plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
    }
    small_plt_ls[[drug]] <- plt +
      ggplot2::guides(fill = "none")
  }
  small_plt_ls[["TCGA-BRCA"]] <- plt_tcga
  # small_plt_ls[[1]] + small_plt_ls[[2]] + small_plt_ls[[3]] + small_plt_ls[[4]] +
  #   patchwork::plot_layout(nrow = 1, ncol = 4, guides = "collect")
  plt <- patchwork::wrap_plots(small_plt_ls, nrow = 1, ncol = 4, guides = "collect") &
    ggplot2::theme(
      plot.title = ggplot2::element_text(hjust = 0.5),
      panel.background = ggplot2::element_rect(fill = "white"),
      panel.grid.major = ggplot2::element_line(color = "grey95", size = ggplot2::rel(0.25))
    )
  ggplot2::ggsave(plt, filename = file.path(figures_dir, "casestudy_stability_select_features.pdf"),
                  width = fig_width * 1, height = fig_height * 0.5)
  
  plt <- small_plt_ls[[1]] + small_plt_ls[[2]] + small_plt_ls[[3]] + patchwork::plot_spacer() + small_plt_ls[[4]] +
    patchwork::plot_layout(nrow = 1, widths = c(1, 1, 1, 0.1, 1), guides = "collect")
  # plt
  # grid::grid.draw(
  #   grid::linesGrob(x = grid::unit(c(0.68, 0.68), "npc"),
  #                   y = grid::unit(c(0.06, 0.98), "npc"))
  # )
  # # ggplot2::ggsave(
  # #   file.path(figures_dir, "casestudy_stability_select_features.pdf"),
  # #   units = "in", width = 14, height = 4
  # # )
}

```


# Misspecified Models {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
keep_methods <- color_df$name %in% c("GMDI_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
manual_color_palette <- color_df$color[keep_methods]
show_methods <- color_df$name[keep_methods]
method_labels <- color_df$label[keep_methods]
alpha_values <- c(1, rep(0.4, length(method_labels) - 1))
legend_position <- c(0.73, 0.35)

vary_param_name <- "heritability_sample_row_n"
y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")

remove_x_axis_models <- c("enhancer", "ccle_rnaseq", "juvenile")
keep_legend_x_models <- c("enhancer")
keep_legend_y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")

for (y_model in y_models) {
  cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n### %s \n\n", x_model))
    plt_ls <- list()
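    # misspecified-model setting: DGPs with omitted (unobserved) signal variables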
    sim_name <- sprintf("%s_%s_dgp_omitted_vars", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    metric_name <- dplyr::case_when(
      metric == "rocauc" ~ "AUROC",
      metric == "prauc" ~ "PRAUC",
      TRUE ~ metric
    )
    fname <- file.path(results_dir, paste0("mdi_plus.regression_sims.", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results()
      saveRDS(results, sprintf("%s.rds", fname))
    }
    if (params$for_paper) {
      for (h in heritabilities) {
        plt <- results %>%
          dplyr::filter(heritability == !!h) %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_row_n",
            facet_str = NULL,
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            legend_position = legend_position,
            custom_theme = custom_theme,
            inside_legend = TRUE
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name, 
            color = "Method", alpha = "Method"
          )
        if (h != heritabilities[1]) {
          plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
        }
        if (x_model %in% remove_x_axis_models) {
          height <- 2.72
          plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
        } else {
          height <- 3
        }
        if (!((h == heritabilities[length(heritabilities)]) & 
              (x_model %in% keep_legend_x_models) & 
              (y_model %in% keep_legend_y_models))) {
          plt <- plt + ggplot2::guides(color = "none", alpha = "none")
        }
        plt_ls[[as.character(h)]] <- plt
      }
      plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
      ggplot2::ggsave(
        file.path(figures_dir, "misspecified_regression_sims", 
                  sprintf("misspecified_regression_sims_%s_%s_%s_errbars.pdf",
                          y_model, x_model, metric)),
        plot = plt, units = "in", width = 14, height = height
      )
    } else {
      plt <- results %>%
        plot_metrics(
          metric = metric,
          x_str = "sample_row_n",
          facet_str = "heritability_name",
          point_size = point_size,
          line_size = line_size,
          errbar_width = errbar_width,
          manual_color_palette = manual_color_palette,
          show_methods = show_methods,
          method_labels = method_labels,
          alpha_values = alpha_values,
          custom_theme = custom_theme
        ) +
        ggplot2::labs(
          x = "Sample Size", y = metric_name,
          color = "Method", alpha = "Method",
          title = sprintf("%s", sim_title)
        )
      vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                           fig_height = fig_height, fig_width = fig_width * 1.3)
      chunk_idx <- chunk_idx + 1
    }
  }
}

```


# Varying Sparsity {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
keep_methods <- color_df$name %in% c("GMDI_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
manual_color_palette <- color_df$color[keep_methods]
show_methods <- color_df$name[keep_methods]
method_labels <- color_df$label[keep_methods]
alpha_values <- c(1, rep(0.4, length(method_labels) - 1))
legend_position <- c(0.73, 0.7)

y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
x_models <- c("juvenile", "splicing")

remove_x_axis_models <- c("splicing")
keep_legend_x_models <- c("splicing")
keep_legend_y_models <- c("linear", "linear_lss_3m_2r")

for (y_model in y_models) {
  cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n### %s \n\n", x_model))
    plt_ls <- list()
    sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    metric_name <- dplyr::case_when(
      metric == "rocauc" ~ "AUROC",
      metric == "prauc" ~ "PRAUC",
      TRUE ~ metric
    )
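    # the varied sparsity parameter is "s" for the linear DGP and "m" otherwise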
    x_str <- ifelse(y_model == "linear", "s", "m")
    vary_param_name <- sprintf("heritability_%s", x_str)
    fname <- file.path(results_dir, paste0("mdi_plus.other_regression_sims.varying_sparsity.", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results()
      saveRDS(results, sprintf("%s.rds", fname))
    }
    if (params$for_paper) {
      for (h in heritabilities) {
        plt <- results %>%
          dplyr::filter(heritability == !!h) %>%
          plot_metrics(
            metric = metric,
            x_str = x_str,
            facet_str = NULL,
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            legend_position = legend_position,
            custom_theme = custom_theme,
            inside_legend = TRUE
          ) +
          ggplot2::labs(
            x = toupper(x_str), y = metric_name, 
            color = "Method", alpha = "Method"
          )
        if (h != heritabilities[1]) {
          plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
        }
        if (x_model %in% remove_x_axis_models) {
          height <- 2.72
          plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
        } else {
          height <- 3
        }
        if (!((h == heritabilities[length(heritabilities)]) & 
              (x_model %in% keep_legend_x_models) & 
              (y_model %in% keep_legend_y_models))) {
          plt <- plt + ggplot2::guides(color = "none", alpha = "none")
        }
        plt_ls[[as.character(h)]] <- plt
      }
      plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
      ggplot2::ggsave(
        file.path(figures_dir, "varying_sparsity", 
                  sprintf("regression_sims_sparsity_%s_%s_%s_errbars.pdf",
                          y_model, x_model, metric)),
        plot = plt, units = "in", width = 14, height = height
      )
    } else {
      plt <- results %>%
        plot_metrics(
          metric = metric,
          x_str = x_str,
          facet_str = "heritability_name",
          point_size = point_size,
          line_size = line_size,
          errbar_width = errbar_width,
          manual_color_palette = manual_color_palette,
          show_methods = show_methods,
          method_labels = method_labels,
          alpha_values = alpha_values,
          custom_theme = custom_theme
        ) +
        ggplot2::labs(
          x = toupper(x_str), y = metric_name,
          color = "Method", alpha = "Method",
          title = sprintf("%s", sim_title)
        )
      vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                           fig_height = fig_height, fig_width = fig_width * 1.3)
      chunk_idx <- chunk_idx + 1
    }
  }
}

```


# Varying \# Features {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
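# Feature ranking accuracy as the number of features (sample_col_n) varies in the CCLE RNASeq-based DGPs,
# using the same set of competing methods as above.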
keep_methods <- color_df$name %in% c("GMDI_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
manual_color_palette <- color_df$color[keep_methods]
show_methods <- color_df$name[keep_methods]
method_labels <- color_df$label[keep_methods]
alpha_values <- c(1, rep(0.4, length(method_labels) - 1))
legend_position <- c(0.73, 0.4)

vary_param_name <- "heritability_sample_col_n"
y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
x_models <- c("ccle_rnaseq")

remove_x_axis_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r")
keep_legend_x_models <- c("ccle_rnaseq")
keep_legend_y_models <- c("linear")

for (y_model in y_models) {
  cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n### %s \n\n", x_model))
    plt_ls <- list()
    sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    metric_name <- dplyr::case_when(
      metric == "rocauc" ~ "AUROC",
      metric == "prauc" ~ "PRAUC"
    )
    fname <- file.path(results_dir, paste0("mdi_plus.other_regression_sims.varying_p.", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results()
      saveRDS(results, sprintf("%s.rds", fname))
    }
    if (params$for_paper) {
      for (h in heritabilities) {
        plt <- results %>%
          dplyr::filter(heritability == !!h) %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_col_n",
            facet_str = NULL,
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            legend_position = legend_position,
            custom_theme = custom_theme,
            inside_legend = TRUE
          ) +
          ggplot2::labs(
            x = "Number of Features", y = metric_name, 
            color = "Method", alpha = "Method"
          )
        if (h != heritabilities[1]) {
          plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
        }
        if (y_model %in% remove_x_axis_models) {
          height <- 2.72
          plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
        } else {
          height <- 3
        }
        if (!((h == heritabilities[length(heritabilities)]) & 
              (x_model %in% keep_legend_x_models) & 
              (y_model %in% keep_legend_y_models))) {
          plt <- plt + ggplot2::guides(color = "none", alpha = "none")
        }
        plt_ls[[as.character(h)]] <- plt
      }
      plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
      ggplot2::ggsave(
        file.path(figures_dir, "varying_p", 
                  sprintf("regression_sims_vary_p_%s_%s_%s_errbars.pdf",
                          y_model, x_model, metric)),
        plot = plt, units = "in", width = 14, height = height
      )
    } else {
      plt <- results %>%
        plot_metrics(
          metric = metric,
          x_str = "sample_col_n",
          facet_str = "heritability_name",
          point_size = point_size,
          line_size = line_size,
          errbar_width = errbar_width,
          manual_color_palette = manual_color_palette,
          show_methods = show_methods,
          method_labels = method_labels,
          alpha_values = alpha_values,
          custom_theme = custom_theme
        ) +
        ggplot2::labs(
          x = "Number of Features", y = metric_name,
          color = "Method", alpha = "Method",
          title = sprintf("%s", sim_title)
        )
      vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                           fig_height = fig_height, fig_width = fig_width * 1.3)
      chunk_idx <- chunk_idx + 1
    }
  }
}

```


# Prediction Results {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
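# Held-out prediction performance of RF vs. RF+ (ridge for the CCLE regression task, logistic for the
# classification tasks): summarizes mean (SE) test metrics and plots the % change from using RF+.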
fpaths <- c(
  "mdi_plus.prediction_sims.ccle_rnaseq_regression-",
  "mdi_plus.prediction_sims.enhancer_classification-",
  "mdi_plus.prediction_sims.splicing_classification-",
  "mdi_plus.prediction_sims.juvenile_classification-",
  "mdi_plus.prediction_sims.tcga_brca_classification-"
)

keep_models <- c("RF", "RF-ridge", "RF-lasso", "RF-logistic")
prediction_metrics <- c(
  "r2", "explained_variance", "mean_squared_error", "mean_absolute_error",
  "rocauc", "prauc", "accuracy", "f1", "recall", "precision", "avg_precision", "logloss"
)

results_ls <- list()
for (fpath in fpaths) {
  fname <- file.path(
    results_dir, 
    fpath,
    sprintf("seed%s", seed),
    "results.csv"
  )
  results <- data.table::fread(fname) %>%
    reformat_results(prediction = TRUE)
  
  if (!("y_task" %in% colnames(results))) {
    results <- results %>%
      dplyr::mutate(y_task = "Results")
  }
  
  plt_df <- results %>%
    tidyr::pivot_longer(
      cols = tidyselect::any_of(prediction_metrics), 
      names_to = "metric", values_to = "value"
    ) %>% 
    dplyr::group_by(y_task, model, metric) %>%
    dplyr::summarise(
      mean = mean(value),
      sd = sd(value),
      n = dplyr::n()
    ) %>%
    dplyr::mutate(
      se = sd / sqrt(n)
    )
  
  if (fpath == "mdi_plus.prediction_sims.ccle_rnaseq_regression") {
    tab <- plt_df %>%
      dplyr::mutate(
        value = sprintf("%.3f (%.3f)", mean, se)
      ) %>%
      tidyr::pivot_wider(id_cols = model, names_from = "metric", values_from = "value") %>%
      dplyr::arrange(dplyr::desc(accuracy)) %>%
      dplyr::select(model, accuracy, f1, rocauc, prauc, precision, recall) %>%
      vthemes::pretty_kable(format = "latex")
  }
  
  if (fpath != "mdi_plus.prediction_sims.ccle_rnaseq_regression-") {
    results_ls[[fpath]] <- plt_df %>%
      dplyr::filter(model %in% c("RF", "RF-logistic")) %>%
      dplyr::select(model, metric, mean) %>%
      tidyr::pivot_wider(id_cols = metric, names_from = "model", values_from = "mean") %>%
      dplyr::mutate(diff = `RF-logistic` - RF,
                    percent_diff = (`RF-logistic` - RF) / abs(RF) * 100)
  } else {
    results_ls[[fpath]] <- plt_df %>%
      dplyr::filter(model %in% c("RF", "RF-ridge")) %>%
      dplyr::select(y_task, model, metric, mean) %>%
      tidyr::pivot_wider(id_cols = c(y_task, metric), names_from = "model", values_from = "mean") %>%
      dplyr::mutate(diff = `RF-ridge` - RF,
                    percent_diff = (`RF-ridge` - RF) / abs(RF) * 100)
  }
}

plt_reg <- results_ls$`mdi_plus.prediction_sims.ccle_rnaseq_regression-` %>%
  dplyr::filter(metric == "r2") %>%
  dplyr::filter(RF > 0.1) %>%
  dplyr::mutate(
    metric = ifelse(metric == "r2", "R-squared", metric)
  ) %>%
  ggplot2::ggplot() +
  ggplot2::aes(x = RF, y = percent_diff, label = y_task) +
  ggplot2::labs(x = "RF Test R-squared", y = "% Change Using RF+ (ridge)") +
  ggplot2::geom_point(size = 3) +
  ggplot2::geom_hline(yintercept = 0, linetype = "dashed") +
  ggrepel::geom_label_repel(fill = "white") +
  ggplot2::facet_wrap(~ metric, scales = "free") +
  custom_theme

plt_reg_all <- results_ls$`mdi_plus.prediction_sims.ccle_rnaseq_regression-` %>%
  dplyr::filter(metric == "r2") %>%
  dplyr::mutate(
    metric = ifelse(metric == "r2", "R-squared", metric)
  ) %>%
  ggplot2::ggplot() +
  ggplot2::aes(x = RF, y = percent_diff, label = y_task) +
  ggplot2::labs(x = "RF Test R-squared", y = "% Change Using RF+ (ridge)") +
  ggplot2::geom_point(size = 3) +
  ggplot2::geom_hline(yintercept = 0, linetype = "dashed") +
  ggrepel::geom_label_repel(fill = "white") +
  ggplot2::facet_wrap(~ metric, scales = "free") +
  custom_theme

if (params$for_paper) {
  ggplot2::ggsave(
    plot = plt_reg_all, 
    file.path(figures_dir, "prediction_results_appendix.pdf"),
    units = "in", width = 8, height = fig_height * 0.75
  )
}

plt_ls <- list(reg = plt_reg)
for (m in c("f1", "prauc")) {
  plt_df <- dplyr::bind_rows(results_ls, .id = "dataset") %>%
    dplyr::filter(metric == m) %>%
    dplyr::mutate(
      dataset = stringr::str_remove(dataset, "^mdi_plus\\.prediction_sims\\.") %>%
        stringr::str_remove("_classification-$") %>%
        stringr::str_to_title() %>%
        stringr::str_replace("Tcga_brca", "TCGA BRCA"),
      diff = ifelse(metric == "logloss", -diff, diff),
      percent_diff = ifelse(metric == "logloss", -percent_diff, percent_diff),
      metric = forcats::fct_recode(metric,
                                   "Accuracy" = "accuracy",
                                   "F1" = "f1",
                                   "Negative Log-loss" = "logloss",
                                   "AUPRC" = "prauc",
                                   "AUROC" = "rocauc")
    )
  plt_ls[[m]] <- plt_df %>%
    ggplot2::ggplot() +
    ggplot2::aes(x = RF, y = percent_diff, label = dataset) +
    ggplot2::labs(x = sprintf("RF Test %s", plt_df$metric[1]), 
                  y = "% Change Using RF+ (logistic)") +
    ggplot2::geom_point(size = 3) +
    ggplot2::geom_hline(yintercept = 0, linetype = "dashed") +
    ggrepel::geom_label_repel(point.padding = 0.5, fill = "white") +
    ggplot2::facet_wrap(~ metric, scales = "free", nrow = 1) +
    custom_theme
}

plt_ls[[3]] <- plt_ls[[3]] +
  ggplot2::theme(axis.title.y = ggplot2::element_blank())

plt <- plt_ls[[1]] + patchwork::plot_spacer() + plt_ls[[2]] + plt_ls[[3]] +
  patchwork::plot_layout(nrow = 1, widths = c(1, 0.3, 1, 1))

# plt
# grid::grid.draw(
#   grid::linesGrob(x = grid::unit(c(0.36, 0.36), "npc"),
#                   y = grid::unit(c(0.06, 0.98), "npc"))
# )
# # ggplot2::ggsave(
# #   file.path(figures_dir, "prediction_results_main.pdf"),
# #   units = "in", width = 14, height = fig_height * 0.75
# # )

vthemes::subchunkify(plt, i = chunk_idx,
                     fig_height = fig_height, fig_width = fig_width * 1.3)
chunk_idx <- chunk_idx + 1
```


# MDI+ Modeling Choices {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
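# Ablation over MDI+ modeling choices (raw-feature inclusion, LOO evaluation, ridge shrinkage),
# compared against MDI and MDI-oob across sample sizes and min_samples_per_leaf settings.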
manual_color_palette <- c(manual_color_palette_choices[method_labels_choices == "MDI"],
                          manual_color_palette_choices[method_labels_choices == "MDI-oob"],
                          "#AF4D98", "#FFD92F", "#FC8D62", "black")
show_methods <- c("MDI_RF", "MDI-oob_RF", "GMDI_raw_RF", "GMDI_loo_RF", "GMDI_ols_raw_loo_RF", "GMDI_ridge_raw_loo_RF")
method_labels <- c("MDI", "MDI-oob", "MDI+ (raw only)", "MDI+ (loo only)", "MDI+ (raw+loo only)", "MDI+ (ridge+raw+loo)")
alpha_values <- rep(1, length(method_labels))
legend_position <- c(0.63, 0.4)

vary_param_name <- "heritability_sample_row_n"
y_models <- c("linear", "hier_poly_3m_2r")
x_models <- c("enhancer", "ccle_rnaseq")#, "splicing")
min_samples_per_leafs <- c(5, 1)

remove_x_axis_models <- c("enhancer")
keep_legend_x_models <- c("enhancer")
keep_legend_y_models <- c("linear")

for (y_model in y_models) {
  cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model))
    plt_ls <- list()
    for (min_samples_per_leaf in min_samples_per_leafs) {
      cat(sprintf("\n\n#### min_samples_per_leaf = %s \n\n", min_samples_per_leaf))
      sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
      sim_title <- dplyr::case_when(
        x_model == "ccle_rnaseq" ~ "CCLE",
        x_model == "splicing" ~ "Splicing",
        x_model == "enhancer" ~ "Enhancer",
        x_model == "juvenile" ~ "Juvenile"
      )
      metric_name <- dplyr::case_when(
        metric == "rocauc" ~ "AUROC",
        metric == "prauc" ~ "PRAUC"
      )
      fname <- file.path(results_dir, 
                         sprintf("mdi_plus.other_regression_sims.modeling_choices_min_samples%s.%s", min_samples_per_leaf, sim_name),
                         paste0("varying_", vary_param_name), 
                         paste0("seed", seed), "results")
      if (!file.exists(sprintf("%s.csv", fname))) {
        next
      }
      if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
        results <- readRDS(sprintf("%s.rds", fname))
        if (length(setdiff(show_methods, unique(results$method))) > 0) {
          results <- data.table::fread(sprintf("%s.csv", fname)) %>%
            reformat_results()
          saveRDS(results, sprintf("%s.rds", fname))
        }
      } else {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
      if (params$for_paper) {
        for (h in heritabilities) {
          plt <- results %>%
            dplyr::filter(heritability == !!h) %>%
            plot_metrics(
              metric = metric,
              x_str = "sample_row_n",
              facet_str = NULL,
              point_size = point_size,
              line_size = line_size,
              errbar_width = errbar_width,
              manual_color_palette = manual_color_palette,
              show_methods = show_methods,
              method_labels = method_labels,
              alpha_values = alpha_values,
              legend_position = legend_position,
              custom_theme = custom_theme,
              inside_legend = TRUE
            ) +
            ggplot2::labs(
              x = "Sample Size", y = metric_name, 
              color = "Method", alpha = "Method"
            ) +
            ggplot2::theme(
              legend.text = ggplot2::element_text(size = 12)
            )
          if (h != heritabilities[1]) {
            plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
          }
          if (x_model %in% remove_x_axis_models) {
            height <- 2.72
            plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          } else {
            height <- 3
          }
          if (!((h == heritabilities[length(heritabilities)]) & 
                (x_model %in% keep_legend_x_models) & 
                (y_model %in% keep_legend_y_models))) {
            plt <- plt + ggplot2::guides(color = "none", alpha = "none")
          }
          plt_ls[[as.character(h)]] <- plt
        }
        plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
        ggplot2::ggsave(
          file.path(figures_dir, "modeling_choices",
                    sprintf("regression_sims_choices_min_samples%s_%s_%s_%s_errbars.pdf",
                            min_samples_per_leaf, y_model, x_model, metric)),
          plot = plt, units = "in", width = 14, height = height
        )
      } else {
        plt <- results %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_row_n",
            facet_str = "heritability_name",
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            custom_theme = custom_theme
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name,
            color = "Method", alpha = "Method",
            title = sprintf("%s", sim_title)
          )
        vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                             fig_height = fig_height, fig_width = fig_width * 1.3)
        chunk_idx <- chunk_idx + 1
      }
    }
  }
}

```


# MDI+ GLM/Metric Choices {.tabset .tabset-vmodern}

## Held-out Test Prediction Scores {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
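# Held-out test prediction accuracy (R-squared) of RF vs. RF+ (ridge) and RF+ (lasso) across sample sizes.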
manual_color_palette <- c("#A5AE9E", "black", "brown")
show_methods <- c("RF", "RF-ridge", "RF-lasso")
method_labels <- c("RF", "RF+ridge", "RF+lasso")
alpha_values <- rep(1, length(method_labels))
legend_position <- c(0.63, 0.4)

vary_param_name <- "heritability_sample_row_n"
y_models <- c("linear", "hier_poly_3m_2r")
x_models <- c("enhancer", "ccle_rnaseq")#, "splicing")
prediction_metrics <- c(
  "r2" #, "mean_absolute_error"
  # "rocauc", "prauc", "accuracy", "f1", "recall", "precision", "avg_precision", "logloss"
)

remove_x_axis_models <- c("enhancer")
keep_legend_x_models <- c("enhancer")
keep_legend_y_models <- c("linear")

for (y_model in y_models) {
  cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n#### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model))
    plt_ls <- list()
    sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    fname <- file.path(results_dir, 
                       sprintf("mdi_plus.glm_metric_choices_sims.regression_prediction_sims.%s", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (!file.exists(sprintf("%s.csv", fname))) {
      next
    }
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results(prediction = TRUE)
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results(prediction = TRUE)
      saveRDS(results, sprintf("%s.rds", fname))
    }
    for (m in prediction_metrics) {
      metric_name <- dplyr::case_when(
        m == "r2" ~ "R-squared",
        m == "mae" ~ "Mean Absolute Error"
      )
      if (params$for_paper) {
        for (h in heritabilities) {
          plt <- results %>%
            dplyr::filter(heritability == !!h) %>%
            plot_metrics(
              metric = m,
              x_str = "sample_row_n",
              facet_str = NULL,
              point_size = point_size,
              line_size = line_size,
              errbar_width = errbar_width,
              manual_color_palette = manual_color_palette,
              show_methods = show_methods,
              method_labels = method_labels,
              alpha_values = alpha_values,
              legend_position = legend_position,
              custom_theme = custom_theme,
              inside_legend = TRUE
            ) +
            ggplot2::labs(
              x = "Sample Size", y = metric_name, 
              color = "Method", alpha = "Method"
            ) +
            ggplot2::theme(
              legend.text = ggplot2::element_text(size = 12)
            )
          if (h != heritabilities[1]) {
            plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
          }
          if (x_model %in% remove_x_axis_models) {
            height <- 2.72
            plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          } else {
            height <- 3
          }
          if (!((h == heritabilities[length(heritabilities)]) & 
                (x_model %in% keep_legend_x_models) & 
                (y_model %in% keep_legend_y_models))) {
            plt <- plt + ggplot2::guides(color = "none", alpha = "none")
          }
          plt_ls[[as.character(h)]] <- plt
        }
        plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
        ggplot2::ggsave(
          file.path(figures_dir, "glm_metric_choices",
                    sprintf("regression_sims_glm_metric_choices_prediction_%s_%s_%s_errbars.pdf",
                            y_model, x_model, m)),
          plot = plt, units = "in", width = 14, height = height
        )
      } else {
        plt <- results %>%
          plot_metrics(
            metric = m,
            x_str = "sample_row_n",
            facet_str = "heritability_name",
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            custom_theme = custom_theme
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name,
            color = "Method", alpha = "Method",
            title = sprintf("%s", sim_title)
          )
        vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                             fig_height = fig_height, fig_width = fig_width * 1.3)
        chunk_idx <- chunk_idx + 1
      }
    }
  }
}

```

## Stability Scores - Regression {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
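# Ranking accuracy and stability (RBO) of MDI+ under different GLM/metric choices
# (ridge vs. lasso, r-squared vs. MAE, ensemble), compared to MDI and MDI-oob, for regression DGPs.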
manual_color_palette <- c("black", "black", "brown", "brown", "gray", 
                                 manual_color_palette_choices[method_labels_choices == "MDI"],
                                 manual_color_palette_choices[method_labels_choices == "MDI-oob"])
show_methods <- c("GMDI_ridge_r2_RF", "GMDI_ridge_neg_mae_RF", "GMDI_lasso_r2_RF", "GMDI_lasso_neg_mae_RF", "GMDI_ensemble_RF", "MDI_RF", "MDI-oob_RF")
method_labels <- c("MDI+ (ridge, r-squared)", "MDI+ (ridge, MAE)", "MDI+ (lasso, r-squared)", "MDI+ (lasso, MAE)", "MDI+ (ensemble)", "MDI", "MDI-oob")
manual_linetype_palette <- c(1, 2, 1, 2, 1, 1, 1)
alpha_values <- c(rep(1, length(method_labels) - 2), rep(0.4, 2))
legend_position <- c(0.63, 0.4)

vary_param_name <- "heritability_sample_row_n"
# y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
# x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")
y_models <- c("linear", "hier_poly_3m_2r")
x_models <- c("enhancer", "ccle_rnaseq")
# stability_metrics <- c("tauAP", "RBO")
stability_metrics <- c("RBO")

remove_x_axis_models <- metric
keep_legend_x_models <- x_models
keep_legend_y_models <- y_models
keep_legend_metrics <- metric

for (y_model in y_models) {
  cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n#### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model))
    sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    fname <- file.path(results_dir, 
                       sprintf("mdi_plus.glm_metric_choices_sims.regression_sims.%s", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (!file.exists(sprintf("%s.csv", fname))) {
      next
    }
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results()
      saveRDS(results, sprintf("%s.rds", fname))
    }
    plt_ls <- list()
    for (m in c(metric, stability_metrics)) {
      metric_name <- dplyr::case_when(
        m == "rocauc" ~ "AUROC",
        m == "prauc" ~ "PRAUC",
        TRUE ~ m
      )
      if (params$for_paper) {
        for (h in heritabilities) {
          plt <- results %>%
            dplyr::filter(heritability == !!h) %>%
            plot_metrics(
              metric = m,
              x_str = "sample_row_n",
              facet_str = NULL,
              linetype_str = "method",
              point_size = point_size,
              line_size = line_size,
              errbar_width = errbar_width,
              manual_color_palette = manual_color_palette,
              show_methods = show_methods,
              method_labels = method_labels,
              alpha_values = alpha_values,
              legend_position = legend_position,
              custom_theme = custom_theme,
              inside_legend = FALSE
            ) +
            ggplot2::labs(
              x = "Sample Size", y = metric_name, 
              color = "Method", alpha = "Method", linetype = "Method"
            ) +
            ggplot2::scale_linetype_manual(
              values = manual_linetype_palette, labels = method_labels
            ) +
            ggplot2::theme(
              legend.text = ggplot2::element_text(size = 12),
              legend.key.width = ggplot2::unit(1.9, "cm")
            )
          if (h != heritabilities[1]) {
            plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
          }
          if (m %in% remove_x_axis_models) {
            height <- 2.72
            plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          } else {
            height <- 3
          }
          if (!((h == heritabilities[length(heritabilities)]) & 
                (x_model %in% keep_legend_x_models) & 
                (y_model %in% keep_legend_y_models) &
                (metric %in% keep_legend_metrics))) {
            plt <- plt + ggplot2::guides(color = "none", alpha = "none", linetype = "none")
          }
          plt_ls[[sprintf("%s_%s", as.character(h), m)]] <- plt
        }
      } else {
        plt <- results %>%
          plot_metrics(
            metric = m,
            x_str = "sample_row_n",
            facet_str = "heritability_name",
            linetype_str = "method",
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            custom_theme = custom_theme
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name,
            color = "Method", alpha = "Method", linetype = "Method",
            title = sprintf("%s", sim_title)
          ) +
          ggplot2::scale_linetype_manual(
            values = manual_linetype_palette, labels = method_labels
          ) +
          ggplot2::theme(
            legend.key.width = ggplot2::unit(1.9, "cm")
          )
        vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                             fig_height = fig_height, fig_width = fig_width * 1.5)
        chunk_idx <- chunk_idx + 1
      }
    }
    if (params$for_paper) {
      nrows <- length(c(metric, stability_metrics))
      plt <- patchwork::wrap_plots(plt_ls, nrow = nrows, guides = "collect")
      ggplot2::ggsave(
        file.path(figures_dir, "glm_metric_choices",
                  sprintf("regression_sims_glm_metric_choices_stability_%s_%s_errbars.pdf",
                          y_model, x_model)),
        plot = plt, units = "in", width = 14 * 1.2, height = height * nrows
      )
    }
  }
}

```

## Stability Scores - Classification {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
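# Ranking accuracy and stability (RBO) of MDI+ GLM/metric choices for classification
# (logistic with log-loss or AUROC vs. ridge with r-squared), compared to MDI and MDI-oob,
# under varying fractions of label corruption.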
manual_color_palette <- c("#9B5DFF", "#9B5DFF", "black",
                                 manual_color_palette_choices[method_labels_choices == "MDI"],
                                 manual_color_palette_choices[method_labels_choices == "MDI-oob"])
show_methods <- c("GMDI_logistic_ridge_logloss_RF", "GMDI_logistic_ridge_auroc_RF", "GMDI_ridge_r2_RF", "MDI_RF", "MDI-oob_RF")
method_labels <- c("MDI+ (logistic, log-loss)", "MDI+ (logistic, AUROC)", "MDI+ (ridge, r-squared)", "MDI", "MDI-oob")
manual_linetype_palette <- c(1, 2, 1, 1, 1)
alpha_values <- c(rep(1, length(method_labels) - 2), rep(0.4, 2))
legend_position <- c(0.63, 0.4)

vary_param_name <- "frac_label_corruption_sample_row_n"
# y_models <- c("logistic", "lss_3m_2r_logistic", "hier_poly_3m_2r_logistic", "linear_lss_3m_2r_logistic")
# x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")
y_models <- c("logistic", "hier_poly_3m_2r_logistic")
x_models <- c("juvenile", "splicing")
# stability_metrics <- c("tauAP", "RBO")
stability_metrics <- "RBO"

remove_x_axis_models <- metric
keep_legend_x_models <- x_models
keep_legend_y_models <- y_models
keep_legend_metrics <- metric

for (y_model in y_models) {
  cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n#### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model))
    sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    fname <- file.path(results_dir, 
                       sprintf("mdi_plus.glm_metric_choices_sims.classification_sims.%s", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (!file.exists(sprintf("%s.csv", fname))) {
      next
    }
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results()
      saveRDS(results, sprintf("%s.rds", fname))
    }
    plt_ls <- list()
    for (m in c(metric, stability_metrics)) {
      metric_name <- dplyr::case_when(
        m == "rocauc" ~ "AUROC",
        m == "prauc" ~ "PRAUC",
        TRUE ~ m
      )
      if (params$for_paper) {
        for (h in frac_label_corruptions) {
          plt <- results %>%
            dplyr::filter(frac_label_corruption_name == !!h) %>%
            plot_metrics(
              metric = m,
              x_str = "sample_row_n",
              facet_str = NULL,
              linetype_str = "method",
              point_size = point_size,
              line_size = line_size,
              errbar_width = errbar_width,
              manual_color_palette = manual_color_palette,
              show_methods = show_methods,
              method_labels = method_labels,
              alpha_values = alpha_values,
              legend_position = legend_position,
              custom_theme = custom_theme,
              inside_legend = FALSE
            ) +
            ggplot2::labs(
              x = "Sample Size", y = metric_name, 
              color = "Method", alpha = "Method", linetype = "Method"
            ) +
            ggplot2::scale_linetype_manual(
              values = manual_linetype_palette, labels = method_labels
            ) +
            ggplot2::theme(
              legend.text = ggplot2::element_text(size = 12),
              legend.key.width = ggplot2::unit(1.9, "cm")
            )
          if (h != frac_label_corruptions[1]) {
            plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
          }
          if (m %in% remove_x_axis_models) {
            height <- 2.72
            plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          } else {
            height <- 3
          }
          if (!((h == frac_label_corruptions[length(frac_label_corruptions)]) & 
                (x_model %in% keep_legend_x_models) & 
                (y_model %in% keep_legend_y_models) &
                (metric %in% keep_legend_metrics))) {
            plt <- plt + ggplot2::guides(color = "none", alpha = "none", linetype = "none")
          }
          plt_ls[[sprintf("%s_%s", as.character(h), m)]] <- plt
        }
      } else {
        plt <- results %>%
          plot_metrics(
            metric = m,
            x_str = "sample_row_n",
            facet_str = "frac_label_corruption_name",
            linetype_str = "method",
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            custom_theme = custom_theme
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name,
            color = "Method", alpha = "Method", linetype = "Method",
            title = sprintf("%s", sim_title)
          ) +
          ggplot2::scale_linetype_manual(
            values = manual_linetype_palette, labels = method_labels
          ) +
          ggplot2::theme(
            legend.key.width = ggplot2::unit(1.9, "cm")
          )
        vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                             fig_height = fig_height, fig_width = fig_width * 1.5)
        chunk_idx <- chunk_idx + 1
      }
    }
    if (params$for_paper) {
      nrows <- length(c(metric, stability_metrics))
      plt <- patchwork::wrap_plots(plt_ls, nrow = nrows, guides = "collect")
      ggplot2::ggsave(
        file.path(figures_dir, "glm_metric_choices",
                  sprintf("classification_sims_glm_metric_choices_stability_%s_%s_errbars.pdf",
                          y_model, x_model)),
        plot = plt, units = "in", width = 14 * 1.2, height = height * nrows
      )
    }
  }
}

```

+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/feature_importance/readme.md b/feature_importance/readme.md
new file mode 100644
index 0000000..188a2f1
--- /dev/null
+++ b/feature_importance/readme.md
@@ -0,0 +1,148 @@
+# Feature Importance Simulations Pipeline
+
+This is a basic template for running extensive simulation studies to benchmark a new feature importance method. The main Python script is `01_run_importance_simulations.py` (which is largely based on `../01_fit_models.py`). To run these simulations, follow the three main steps below:
+
+1. Take the feature importance method of interest, and wrap it in a function that has the following structure:
+ - Inputs:
+ - `X`: A data frame of covariates/design matrix.
+ - `y`: Response vector. [Note that this argument is required even if it is not used by the feature importance method.]
+ - `fit`: A fitted estimator (e.g., a fitted RandomForestRegressor).
+ - Additional input arguments are allowed, but `X`, `y`, and `fit` are required at a minimum.
+ - Output:
+ - A data frame with at least the columns `var` and `importance`, containing the variable ID and the importance score, respectively. Additional columns are also permitted.
+ - For examples of this feature importance wrapper, see `scripts/competing_methods.py`; a minimal sketch is also given after this list.
+2. Update configuration files (in `fi_config/`) to set the data-generating process(es), prediction models, and feature importance estimators to run in the simulation. See below for additional information and examples of these configuration files.
+3. Run `01_run_importance_simulations.py` and pass in the appropriate command-line arguments. See below for additional information and examples.
+ - If `--create_rmd` is passed in as an argument in step 3, an HTML document with some basic visualization summaries of the results is generated automatically. These results are rendered using R Markdown via `rmd/simulation_results.Rmd` and are saved in the results folder specified in step 3.
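+
+As an illustration of the wrapper structure described in step 1, below is a minimal sketch. Only the `X`/`y`/`fit` inputs and the `var`/`importance` output columns are dictated by the pipeline; the wrapper name and the use of scikit-learn's permutation importance here are purely illustrative.
+
+```python
+import numpy as np
+import pandas as pd
+from sklearn.inspection import permutation_importance
+
+
+def permutation_importance_wrapper(X, y, fit, n_repeats=10, random_state=0):
+    """Illustrative wrapper: permutation importance in the pipeline's expected format."""
+    # `fit` is the already-fitted estimator passed in by the pipeline
+    result = permutation_importance(fit, X, y, n_repeats=n_repeats, random_state=random_state)
+    return pd.DataFrame({
+        "var": np.arange(X.shape[1]),           # variable ID
+        "importance": result.importances_mean,  # importance score per feature
+    })
+```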
+
+Notes:
+ - To apply the feature importance method to real data (or any setting where the true support/signal features are unknown), one can use `02_run_importance_real_data.py` instead of `01_run_importance_simulations.py`.
+ - To evaluate the prediction accuracy of the model fits, see `03_run_prediction_simulations.py` and `04_run_prediction_real_data.py` for simulated and real data, respectively.
+
+Additional details for steps 2 and 3 are provided below.
+
+
+## Creating the config files (step 2)
+
+For a starter template, see the `fi_config/test` folder. There are two necessary files (a minimal sketch follows this list):
+
+- `dgp.py`: Script specifying the data-generating process under study. The following variables must be provided:
+ - `X_DGP`: Function to generate X data.
+ - `X_PARAMS_DICT`: Dictionary of named arguments to pass into the `X_DGP` function.
+ - `Y_DGP`: Function to generate y data.
+ - `Y_PARAMS_DICT`: Dictionary of named arguments to pass into the `Y_DGP` function.
+ - `VARY_PARAM_NAME`: Name of the argument (typically an argument of `X_DGP` or `Y_DGP`) to vary in the experiment. Note that it is also possible to vary an argument of an `ESTIMATOR` in very basic simulation setups. This can also be a vector of parameter names, in which case the parameters are varied over a grid.
+ - `VARY_PARAM_VALS`: Dictionary of named arguments for the `VARY_PARAM_NAME` to take on in the simulation experiment. Note that the value can be any python object, but make sure to keep the key simple for naming and plotting purposes.
+- `models.py`: Script specifying the prediction methods and feature importance estimators under study. The following variables must be provided:
+ - `ESTIMATORS`: List of prediction methods to fit. Elements should be of class `ModelConfig`.
+ - Note that the class passed into `ModelConfig` should have a `fit` method (e.g., like sklearn models).
+ - Additional arguments to pass to the model class can be specified in a dictionary using the `other_params` argument in `ModelConfig()`.
+ - `FI_ESTIMATORS`: List of feature importance methods to fit. Elements should be of class `FIModelConfig`.
+ - Note that the function passed into `FIModelConfig` should take in the arguments `X`, `y`, `fit` at a minimum. For examples, see `scripts/competing_methods.py`.
+ - Additional arguments to pass to the feature importance function can be specified in a dictionary using the `other_params` argument in `FIModelConfig()`.
+ - Pair up a prediction model and feature importance estimator by using the same `model_type` ID in both.
+ - Note that by default, higher values from the feature importance method are assumed to indicate higher importance. If higher values indicate lower importance, set `ascending=False` in `FIModelConfig()`.
+ - If a feature importance estimator requires sample splitting (outside of the feature importance function call), use the `splitting_strategy` argument to specify the type of splitting strategy (e.g., `train-test`).
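+
+As a concrete (hypothetical) illustration of the config structure, a minimal `fi_config/my_sim/dgp.py` might look like the sketch below. The generator functions are defined inline and are purely illustrative; in particular, consult `fi_config/test/dgp.py` for the exact return contract expected of `Y_DGP` (e.g., whether the true support must also be exposed).
+
+```python
+import numpy as np
+
+
+def sample_normal_X(n, d):
+    """Illustrative X generator: n samples of d iid standard normal features."""
+    return np.random.randn(n, d)
+
+
+def linear_y(X, s=5, sigma=1.0):
+    """Illustrative y generator: linear signal in the first s features plus Gaussian noise."""
+    return X[:, :s].sum(axis=1) + sigma * np.random.randn(X.shape[0])
+
+
+X_DGP = sample_normal_X
+X_PARAMS_DICT = {"n": 250, "d": 100}
+Y_DGP = linear_y
+Y_PARAMS_DICT = {"s": 5, "sigma": 1.0}
+
+# vary the sample size n over a small grid
+VARY_PARAM_NAME = "n"
+VARY_PARAM_VALS = {"100": 100, "250": 250, "500": 500}
+```
+
+The accompanying `models.py` would then define the `ESTIMATORS` and `FI_ESTIMATORS` lists of `ModelConfig` and `FIModelConfig` objects as described above; see `fi_config/test/models.py` for their exact constructor signatures.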
+
+For an example of the fi_config files used for real data case studies, see the `fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/` folder. Like before, there are two necessary files: `dgp.py` and `models.py`. `models.py` follows the same structure and requirements as above. `dgp.py` is a script that defines two variables, `X_PATH` and `Y_PATH`, specifying the file paths of the covariate/feature data `X` and the response data `y`, respectively.
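+
+For instance, a minimal real-data `dgp.py` consists of just two path definitions (the paths below are placeholders):
+
+```python
+X_PATH = "data/my_study/X.csv"  # placeholder path to the covariate/feature matrix
+Y_PATH = "data/my_study/y.csv"  # placeholder path to the response data
+```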
+
+
+## Running the simulations (step 3)
+
+**For running simulations to evaluate feature importance rankings:** `01_run_importance_simulations.py`
+
+- Command Line Arguments:
+ - Simulation settings:
+ - `--nreps`: Number of replicates.
+ - `--model`: Name of prediction model to run. Default (`None`) uses all models specified in `models.py` config file.
+ - `--fi_model`: Name of feature importance estimator to run. Default (`None`) uses all feature importance estimators specified in `models.py` config file.
+ - `--config`: Name of fi_config folder (and title of the simulation experiment).
+ - `--omit_vars`: (Optional) Comma-separated string of feature indices to omit (as unobserved variables). More specifically, these features may be used in generating the response *y* but are omitted from the *X* used in training/evaluating the prediction model and feature importance estimator.
+ - Computational settings:
+ - `--nosave_cols`: (Optional) Comma-separated string of column names to omit in the output file (to avoid potential errors when saving to pickle).
+ - `--ignore_cache`: Whether or not to ignore cached results.
+ - `--verbose`: Whether or not to print messages.
+ - `--parallel`: Whether or not to run replicates in parallel.
+ - `--parallel_id`: ID for parallelization.
+ - `--n_cores`: Number of cores if running in parallel. Default uses all available cores.
+ - `--split_seed`: Seed for data splitting.
+ - `--results_path`: Path to save results. Default is `./results`.
+ - R Markdown output options:
+ - `--create_rmd`: Whether or not to output R Markdown-generated html file with summary of results.
+ - `--show_vars`: Max number of features to show in rejection probability plots in the R Markdown. Default (`None`) is to show all variables.
+- Example usage in command line:
+```
+python 01_run_importance_simulations.py --nreps 100 --config test --split_seed 331 --ignore_cache --create_rmd
+```
+
+**For running feature importance methods on real data:** `02_run_importance_real_data.py`
+
+- Command Line Arguments:
+ - Simulation settings:
+ - `--nreps`: Number of replicates (or times to run method on the given data).
+ - `--model`: Name of prediction model to run. Default (`None`) uses all models specified in `models.py` config file.
+ - `--fi_model`: Name of feature importance estimator to run. Default (`None`) uses all feature importance estimators specified in `models.py` config file.
+ - `--config`: Name of fi_config folder (and title of the simulation experiment).
+ - `--response_idx`: (Optional) Name of response column to use if response data *y* is a matrix or multi-task. If not provided, independent regression/classification problems are fitted for every column of *y* separately. If *y* is not a matrix, this argument is ignored.
+ - Computational settings:
+ - `--nosave_cols`: (Optional) Comma-separated string of column names to omit in the output file (to avoid potential errors when saving to pickle).
+ - `--ignore_cache`: Whether or not to ignore cached results.
+ - `--verbose`: Whether or not to print messages.
+ - `--parallel`: Whether or not to run replicates in parallel.
+ - `--parallel_id`: ID for parallelization.
+ - `--n_cores`: Number of cores if running in parallel. Default uses all available cores.
+ - `--split_seed`: Seed for data splitting.
+ - `--results_path`: Path to save results. Default is `./results`.
+- Example usage in command line:
+```
+python 02_run_importance_real_data.py --nreps 1 --config test --split_seed 331 --ignore_cache
+```
+
+**For running simulations to evaluate prediction accuracy of the model fits:** `03_run_prediction_simulations.py`
+
+- Command Line Arguments:
+ - Simulation settings:
+ - `--nreps`: Number of replicates.
+ - `--mode`: One of 'regression' or 'binary_classification'.
+ - `--model`: Name of prediction model to run. Default (`None`) uses all models specified in `models.py` config file.
+ - `--config`: Name of fi_config folder (and title of the simulation experiment).
+ - `--omit_vars`: (Optional) Comma-separated string of feature indices to omit (as unobserved variables). More specifically, these features may be used in generating the response *y* but are omitted from the *X* used in training/evaluating the prediction model and feature importance estimator.
+ - Computational settings:
+ - `--nosave_cols`: (Optional) Comma-separated string of column names to omit in the output file (to avoid potential errors when saving to pickle).
+ - `--ignore_cache`: Whether or not to ignore cached results.
+ - `--splitting_strategy`: One of 'train-test', 'train-tune-test', 'train-test-lowdata', or 'train-tune-test-lowdata', indicating how to split the data into training and test sets.
+ - `--verbose`: Whether or not to print messages.
+ - `--parallel`: Whether or not to run replicates in parallel.
+ - `--parallel_id`: ID for parallelization.
+ - `--n_cores`: Number of cores if running in parallel. Default uses all available cores.
+ - `--split_seed`: Seed for data splitting.
+ - `--results_path`: Path to save results. Default is `./results`.
+- Example usage in command line:
+```
+python 03_run_prediction_simulations.py --nreps 100 --config mdi_plus.glm_metric_choices_sims.regression_prediction_sims.enhancer_linear_dgp --mode regression --split_seed 331 --ignore_cache --nosave_cols prediction_model
+```
+
+**For running prediction methods on real data:** `04_run_prediction_real_data.py`
+
+- Command Line Arguments:
+ - Simulation settings:
+ - `--nreps`: Number of replicates.
+ - `--mode`: One of 'regression' or 'binary_classification'.
+ - `--model`: Name of prediction model to run. Default (`None`) uses all models specified in `models.py` config file.
+ - `--config`: Name of fi_config folder (and title of the simulation experiment).
+ - `--response_idx`: (Optional) Name of response column to use if response data *y* is a matrix or multi-task. If not provided, independent regression/classification problems are fitted for every column of *y* separately. If *y* is not a matrix, this argument is ignored.
+ - `--subsample_n`: (Optional) Integer indicating the maximum number of samples to use in training the prediction model. If None, no subsampling occurs.
+ - Computational settings:
+ - `--nosave_cols`: (Optional) Comma-separated string of column names to omit in the output file (to avoid potential errors when saving to pickle).
+ - `--ignore_cache`: Whether or not to ignore cached results.
+ - `--splitting_strategy`: One of 'train-test', 'train-tune-test', 'train-test-lowdata', or 'train-tune-test-lowdata', indicating how to split the data into training and test sets.
+ - `--verbose`: Whether or not to print messages.
+ - `--parallel`: Whether or not to run replicates in parallel.
+ - `--parallel_id`: ID for parallelization.
+ - `--n_cores`: Number of cores if running in parallel. Default uses all available cores.
+ - `--split_seed`: Seed for data splitting.
+ - `--results_path`: Path to save results. Default is `./results`.
+- Example usage in command line:
+```
+python 04_run_prediction_real_data.py --nreps 10 --config mdi_plus.prediction_sims.enhancer_classification- --mode binary_classification --split_seed 331 --ignore_cache --nosave_cols prediction_model
+```
\ No newline at end of file
diff --git a/feature_importance/rmd/simulation_results.Rmd b/feature_importance/rmd/simulation_results.Rmd
new file mode 100644
index 0000000..8aef415
--- /dev/null
+++ b/feature_importance/rmd/simulation_results.Rmd
@@ -0,0 +1,717 @@
+---
+title: "Simulation Results"
+author: ""
+date: "`r format(Sys.time(), '%B %d, %Y')`"
+output: vthemes::vmodern
+params:
+ results_dir:
+ label: "Results directory"
+ value: "results/test"
+ vary_param_name:
+ label: "Name of varying parameter"
+ value: "sample_row_n"
+ seed:
+ label: "Seed"
+ value: 0
+ keep_vars:
+ label: "Max variables to keep in rejection probability plots"
+ value: 100
+ rm_fi:
+ label: "Feature importance methods to omit"
+ value: NULL
+ abridged:
+ label: "Abridged Document"
+ value: TRUE
+---
+
+```{r setup, include=FALSE}
+knitr::opts_chunk$set(echo = FALSE, warning = FALSE, message = FALSE)
+
+library(magrittr)
+chunk_idx <- 1
+
+# set parameters
+results_dir <- params$results_dir
+vary_param_name <- params$vary_param_name
+if (is.null(vary_param_name)) {
+ vary_param_name_vec <- NULL
+} else if (stringr::str_detect(vary_param_name, ";")) {
+ vary_param_name_vec <- stringr::str_split(vary_param_name, "; ")[[1]]
+ vary_param_name <- paste(vary_param_name_vec, collapse = "_")
+ if (length(vary_param_name_vec) != 2) {
+ stop("Rmarkdown report has not been configured to show results when >2 parameters are being varied.")
+ }
+} else {
+ vary_param_name_vec <- NULL
+}
+seed <- params$seed
+if (!is.null(params$keep_vars)) {
+ keep_vars <- 0:params$keep_vars
+} else {
+ keep_vars <- params$keep_vars
+}
+abridged <- params$abridged
+```
+
+```{r helper-functions, echo = FALSE}
+# reformat results
+reformat_results <- function(results) {
+ if (!is.null(params$rm_fi)) {
+ results <- results %>%
+ dplyr::filter(!(fi %in% params$rm_fi))
+ }
+ results_grouped <- results %>%
+ dplyr::group_by(index) %>%
+ tidyr::nest(fi_scores = var:(tidyselect::last_col())) %>%
+ dplyr::ungroup() %>%
+ dplyr::select(-index) %>%
+ # join fi+model to get method column
+ tidyr::unite(col = "method", fi, model, na.rm = TRUE, remove = FALSE) %>%
+ dplyr::mutate(
+ # get rid of duplicate RF in r2f method name
+ method = ifelse(stringr::str_detect(method, "^r2f.*RF$"),
+ stringr::str_remove(method, "\\_RF$"), method)
+ )
+ return(results_grouped)
+}
+
+# plot metrics (mean value across repetitions with error bars)
+plot_metrics <- function(results, vary_param_name, vary_param_name_vec,
+ show_errbars = TRUE) {
+ if (!is.null(vary_param_name_vec)) {
+ vary_param_name <- vary_param_name_vec
+ }
+ plt_df <- results %>%
+ dplyr::select(rep, method,
+ tidyselect::all_of(c(paste0(vary_param_name, "_name"),
+ metrics))) %>%
+ tidyr::pivot_longer(
+ cols = tidyselect::all_of(metrics), names_to = "metric"
+ ) %>%
+ dplyr::group_by(
+ dplyr::across(tidyselect::all_of(paste0(vary_param_name, "_name"))),
+ method, metric
+ ) %>%
+ dplyr::summarise(mean = mean(value),
+ sd = sd(value) / sqrt(dplyr::n()),
+ .groups = "keep")
+
+ if (is.null(vary_param_name_vec)) {
+ if (length(unique(plt_df[[paste0(vary_param_name, "_name")]])) == 1) {
+ plt <- ggplot2::ggplot(plt_df) +
+ ggplot2::aes(x = method, y = mean) +
+ ggplot2::facet_wrap(~ metric, scales = "free_y", nrow = 1, ncol = 2) +
+ ggplot2::geom_point() +
+ vthemes::theme_vmodern() +
+ vthemes::scale_color_vmodern(discrete = TRUE) +
+ ggplot2::labs(x = "Method")
+ if (show_errbars) {
+ plt <- plt +
+ ggplot2::geom_errorbar(
+ mapping = ggplot2::aes(x = method, ymin = mean - sd, ymax = mean + sd),
+ width = 0
+ )
+ }
+ } else {
+ plt <- ggplot2::ggplot(plt_df) +
+ ggplot2::aes(x = .data[[paste0(vary_param_name, "_name")]],
+ y = mean, color = method) +
+ ggplot2::facet_wrap(~ metric, scales = "free_y", nrow = 1, ncol = 2) +
+ ggplot2::geom_point() +
+ ggplot2::geom_line() +
+ vthemes::theme_vmodern() +
+ vthemes::scale_color_vmodern(discrete = TRUE) +
+ ggplot2::labs(x = vary_param_name)
+ if (show_errbars) {
+ plt <- plt +
+ ggplot2::geom_errorbar(
+ mapping = ggplot2::aes(x = .data[[paste0(vary_param_name, "_name")]],
+ ymin = mean - sd, ymax = mean + sd),
+ width = 0
+ )
+ }
+ }
+ } else {
+ plt <- plt_df %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = .data[[paste0(vary_param_name[2], "_name")]],
+ y = mean, color = method) +
+ ggplot2::facet_grid(metric ~ .data[[paste0(vary_param_name[1], "_name")]]) +
+ ggplot2::geom_point() +
+ ggplot2::geom_line() +
+ vthemes::theme_vmodern() +
+ vthemes::scale_color_vmodern(discrete = TRUE) +
+ ggplot2::labs(x = vary_param_name[2])
+ if (show_errbars) {
+ plt <- plt +
+ ggplot2::geom_errorbar(
+ mapping = ggplot2::aes(x = .data[[paste0(vary_param_name[2], "_name")]],
+ ymin = mean - sd, ymax = mean + sd),
+ width = 0
+ )
+ }
+ }
+ return(plt)
+}
+
+# plot restricted auroc/auprc
+plot_restricted_metrics <- function(results, vary_param_name,
+ vary_param_name_vec,
+ quantiles = c(.1, .2, .3, .4),
+ show_errbars = TRUE) {
+ if (!is.null(vary_param_name_vec)) {
+ vary_param_name <- vary_param_name_vec
+ }
+ results <- results %>%
+ dplyr::select(rep, method, fi_scores,
+ tidyselect::all_of(c(paste0(vary_param_name, "_name")))) %>%
+ dplyr::mutate(
+ vars_ordered = purrr::map(
+ fi_scores,
+ function(fi_df) {
+ fi_df %>%
+ dplyr::filter(!is.na(cor_with_signal)) %>%
+ dplyr::arrange(-cor_with_signal) %>%
+ dplyr::pull(var)
+ }
+ )
+ )
+
+ plt_df_ls <- list()
+ for (q in quantiles) {
+ plt_df_ls[[as.character(q)]] <- results %>%
+ dplyr::mutate(
+ restricted_metrics = purrr::map2_dfr(
+ fi_scores, vars_ordered,
+ function(fi_df, ignore_vars) {
+ ignore_vars <- ignore_vars[1:round(q * length(ignore_vars))]
+ auroc_r <- fi_df %>%
+ dplyr::filter(!(var %in% ignore_vars)) %>%
+ yardstick::roc_auc(
+ truth = factor(true_support, levels = c("1", "0")), importance,
+ event_level = "first"
+ ) %>%
+ dplyr::pull(.estimate)
+ auprc_r <- fi_df %>%
+ dplyr::filter(!(var %in% ignore_vars)) %>%
+ yardstick::pr_auc(
+ truth = factor(true_support, levels = c("1", "0")), importance,
+ event_level = "first"
+ ) %>%
+ dplyr::pull(.estimate)
+ return(data.frame(restricted_auroc = auroc_r,
+ restricted_auprc = auprc_r))
+ }
+ )
+ ) %>%
+ tidyr::unnest(restricted_metrics) %>%
+ tidyr::pivot_longer(
+ cols = c(restricted_auroc, restricted_auprc), names_to = "metric"
+ ) %>%
+ dplyr::group_by(
+ dplyr::across(tidyselect::all_of(paste0(vary_param_name, "_name"))),
+ method, metric
+ ) %>%
+ dplyr::summarise(mean = mean(value),
+ sd = sd(value) / sqrt(dplyr::n()),
+ .groups = "keep") %>%
+ dplyr::ungroup()
+ }
+ plt_df <- purrr::map_dfr(plt_df_ls, ~.x, .id = ".threshold") %>%
+ dplyr::mutate(.threshold = as.numeric(.threshold))
+
+ if (is.null(vary_param_name_vec)) {
+ if (length(unique(plt_df[[paste0(vary_param_name, "_name")]])) == 1) {
+ plt <- ggplot2::ggplot(plt_df) +
+ ggplot2::aes(x = method, y = mean) +
+ ggplot2::facet_grid(metric ~ .threshold, scales = "free_y") +
+ ggplot2::geom_point() +
+ vthemes::theme_vmodern() +
+ vthemes::scale_color_vmodern(discrete = TRUE) +
+ ggplot2::labs(x = "Method")
+ if (show_errbars) {
+ plt <- plt +
+ ggplot2::geom_errorbar(
+ mapping = ggplot2::aes(x = method, ymin = mean - sd, ymax = mean + sd),
+ width = 0
+ )
+ }
+ } else {
+ plt <- ggplot2::ggplot(plt_df) +
+ ggplot2::aes(x = .data[[paste0(vary_param_name, "_name")]],
+ y = mean, color = method) +
+ ggplot2::facet_grid(metric ~ .threshold, scales = "free_y") +
+ ggplot2::geom_point() +
+ ggplot2::geom_line() +
+ vthemes::theme_vmodern() +
+ vthemes::scale_color_vmodern(discrete = TRUE) +
+ ggplot2::labs(x = vary_param_name)
+ if (show_errbars) {
+ plt <- plt +
+ ggplot2::geom_errorbar(
+ mapping = ggplot2::aes(x = .data[[paste0(vary_param_name, "_name")]],
+ ymin = mean - sd, ymax = mean + sd),
+ width = 0
+ )
+ }
+ }
+ } else {
+ plt <- plt_df %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = .data[[paste0(vary_param_name[2], "_name")]],
+ y = mean, color = method) +
+ ggplot2::facet_grid(metric + .threshold ~ .data[[paste0(vary_param_name[1], "_name")]]) +
+ ggplot2::geom_point() +
+ ggplot2::geom_line() +
+ vthemes::theme_vmodern() +
+ vthemes::scale_color_vmodern(discrete = TRUE) +
+ ggplot2::labs(x = vary_param_name[2])
+ if (show_errbars) {
+ plt <- plt +
+ ggplot2::geom_errorbar(
+ mapping = ggplot2::aes(x = .data[[paste0(vary_param_name[2], "_name")]],
+ ymin = mean - sd, ymax = mean + sd),
+ width = 0
+ )
+ }
+ }
+ return(plt)
+}
+
+# plot true positive rate across # positives
+plot_tpr <- function(results, vary_param_name, vary_param_name_vec) {
+ if (!is.null(vary_param_name_vec)) {
+ vary_param_name <- vary_param_name_vec
+ }
+ if (is.null(results)) {
+ return(NULL)
+ }
+
+ plt_df <- results %>%
+ dplyr::mutate(
+ fi_scores = mapply(name = fi, scores_df = fi_scores,
+ function(name, scores_df) {
+ scores_df <- scores_df %>%
+ dplyr::mutate(
+ ranking = rank(-importance,
+ ties.method = "random")
+ ) %>%
+ dplyr::arrange(ranking) %>%
+ dplyr::mutate(
+ .tp = cumsum(true_support) / sum(true_support)
+ )
+ return(scores_df)
+ }, SIMPLIFY = FALSE)
+ ) %>%
+ tidyr::unnest(fi_scores) %>%
+ dplyr::select(tidyselect::all_of(paste0(vary_param_name, "_name")),
+ rep, method, ranking, .tp) %>%
+ dplyr::group_by(
+ dplyr::across(tidyselect::all_of(paste0(vary_param_name, "_name"))),
+ method, ranking
+ ) %>%
+ dplyr::summarise(.tp = mean(.tp), .groups = "keep")
+
+ if (is.null(vary_param_name_vec)) {
+ plt <- ggplot2::ggplot(plt_df) +
+ ggplot2::aes(x = ranking, y = .tp, color = method) +
+ ggplot2::geom_line() +
+ ggplot2::facet_wrap(reformulate(paste0(vary_param_name, "_name"))) +
+ ggplot2::labs(x = "Top n", y = "True Positive Rate", fill = "Method") +
+ vthemes::scale_color_vmodern(discrete = TRUE) +
+ vthemes::theme_vmodern()
+ } else {
+ plt <- ggplot2::ggplot(plt_df) +
+ ggplot2::aes(x = ranking, y = .tp, color = method) +
+ ggplot2::geom_line() +
+ ggplot2::facet_grid(
+ reformulate(paste0(vary_param_name[1], "_name"),
+ paste0(vary_param_name[2], "_name"))
+ ) +
+ ggplot2::labs(x = "Top n", y = "True Positives Rate", fill = "Method") +
+ vthemes::scale_color_vmodern(discrete = TRUE) +
+ vthemes::theme_vmodern()
+ }
+ return(plt)
+}
+
+# plot feature importances
+plot_feature_importance <- function(results, vary_param_name,
+ vary_param_name_vec,
+ keep_vars = NULL,
+ plot_type = c("boxplot", "bar")) {
+ if (!is.null(vary_param_name_vec)) {
+ vary_param_name <- vary_param_name_vec
+ }
+
+ plot_type <- match.arg(plot_type)
+ plt_df <- results %>%
+ tidyr::unnest(fi_scores)
+ if (plot_type == "bar") {
+ plt_df <- plt_df %>%
+ dplyr::select(tidyselect::all_of(paste0(vary_param_name, "_name")),
+ rep, method, var, importance) %>%
+ dplyr::group_by(
+ dplyr::across(tidyselect::all_of(paste0(vary_param_name, "_name"))),
+ method, var
+ ) %>%
+ dplyr::summarise(mean_fi = mean(importance), .groups = "keep")
+ }
+ if (!is.null(keep_vars)) {
+ plt_df <- plt_df %>%
+ dplyr::filter(var %in% keep_vars)
+ }
+ plt_ls <- list()
+ if (is.null(vary_param_name_vec)) {
+ for (val in unique(plt_df[[paste0(vary_param_name, "_name")]])) {
+ if (plot_type == "bar") {
+ plt <- plt_df %>%
+ dplyr::filter(.data[[paste0(vary_param_name, "_name")]] == !!val) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = var, y = mean_fi) +
+ ggplot2::geom_bar(stat = "identity", color = "grey98",
+ fill = "#00C5FF") +
+ ggplot2::facet_wrap(~ method, scales = "free",
+ ncol = 2,
+ nrow = ceiling(length(unique(plt_df$method)) / 2)) +
+ ggplot2::labs(title = sprintf("%s = %s", vary_param_name, val),
+ x = "Feature", y = "Mean Importance / Significance") +
+ vthemes::theme_vmodern()
+ } else if (plot_type == "boxplot") {
+ plt <- plt_df %>%
+ dplyr::filter(.data[[paste0(vary_param_name, "_name")]] == !!val) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = var, y = importance, group = var) +
+ ggplot2::geom_boxplot() +
+ ggplot2::facet_wrap(~ method, scales = "free",
+ ncol = 2,
+ nrow = ceiling(length(unique(plt_df$method)) / 2)) +
+ ggplot2::labs(title = sprintf("%s = %s", vary_param_name, val),
+ x = "Feature", y = "Importance / Significance") +
+ vthemes::theme_vmodern()
+ }
+ plt_ls[[as.character(val)]] <- plt
+ }
+ } else {
+ for (val1 in unique(plt_df[[paste0(vary_param_name[1], "_name")]])) {
+ plt_ls[[as.character(val1)]] <- list()
+ for (val2 in unique(plt_df[[paste0(vary_param_name[2], "_name")]])) {
+ if (plot_type == "bar") {
+ plt <- plt_df %>%
+ dplyr::filter(
+ .data[[paste0(vary_param_name[1], "_name")]] == !!val1,
+ .data[[paste0(vary_param_name[2], "_name")]] == !!val2
+ ) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = var, y = mean_fi) +
+ ggplot2::geom_bar(stat = "identity", color = "grey98", fill = "#00C5FF") +
+ ggplot2::facet_wrap(~ method, scales = "free",
+ ncol = 2,
+ nrow = ceiling(length(unique(plt_df$method)) / 2)) +
+ ggplot2::labs(title = sprintf("%s = %s; %s = %s",
+ vary_param_name[1], val1,
+ vary_param_name[2], val2),
+ x = "Feature", y = "Mean Importance / Significance") +
+ vthemes::theme_vmodern()
+ } else if (plot_type == "boxplot") {
+ plt <- plt_df %>%
+ dplyr::filter(
+ .data[[paste0(vary_param_name[1], "_name")]] == !!val1,
+ .data[[paste0(vary_param_name[2], "_name")]] == !!val2
+ ) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = var, y = importance, group = var) +
+ ggplot2::geom_boxplot() +
+ ggplot2::facet_wrap(~ method, scales = "free",
+ ncol = 2,
+ nrow = ceiling(length(unique(plt_df$method)) / 2)) +
+ ggplot2::labs(title = sprintf("%s = %s; %s = %s",
+ vary_param_name[1], val1,
+ vary_param_name[2], val2),
+ x = "Feature", y = "Importance / Significance") +
+ vthemes::theme_vmodern()
+ }
+ plt_ls[[as.character(val1)]][[as.character(val2)]] <- plt
+ }
+ }
+ }
+
+ return(plt_ls)
+}
+
+# plot ranking heatmap
+plot_ranking_heatmap <- function(results, vary_param_name, vary_param_name_vec,
+ keep_vars = NULL) {
+ if (!is.null(vary_param_name_vec)) {
+ vary_param_name <- vary_param_name_vec
+ }
+
+ plt_df <- results %>%
+ dplyr::mutate(
+ fi_scores = mapply(name = fi, scores_df = fi_scores,
+ function(name, scores_df) {
+ scores_df <- scores_df %>%
+ dplyr::mutate(ranking = rank(-importance))
+ return(scores_df)
+ }, SIMPLIFY = FALSE)
+ ) %>%
+ tidyr::unnest(fi_scores) %>%
+ dplyr::select(tidyselect::all_of(paste0(vary_param_name, "_name")),
+ rep, method, var, ranking, importance)
+
+ if (!is.null(keep_vars)) {
+ plt_df <- plt_df %>%
+ dplyr::filter(var %in% keep_vars)
+ }
+ plt_ls <- list()
+ if (is.null(vary_param_name_vec)) {
+ for (val in unique(plt_df[[paste0(vary_param_name, "_name")]])) {
+ plt <- plt_df %>%
+ dplyr::filter(.data[[paste0(vary_param_name, "_name")]] == !!val) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = var, y = rep, fill = ranking, text = importance) +
+ ggplot2::geom_tile() +
+ ggplot2::facet_wrap(~ method, scales = "free",
+ ncol = 2,
+ nrow = ceiling(length(unique(plt_df$method)) / 2)) +
+ ggplot2::coord_cartesian(expand = FALSE) +
+ ggplot2::labs(title = sprintf("%s = %s", vary_param_name, val),
+ x = "Feature", y = "Replicate", fill = "Ranking") +
+ vthemes::scale_fill_vmodern() +
+ vthemes::theme_vmodern()
+ plt_ls[[as.character(val)]] <- plt
+ }
+ } else {
+ for (val1 in unique(plt_df[[paste0(vary_param_name[1], "_name")]])) {
+ plt_ls[[as.character(val1)]] <- list()
+ for (val2 in unique(plt_df[[paste0(vary_param_name[2], "_name")]])) {
+ plt <- plt_df %>%
+ dplyr::filter(
+ .data[[paste0(vary_param_name[1], "_name")]] == !!val1,
+ .data[[paste0(vary_param_name[2], "_name")]] == !!val2
+ ) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = var, y = rep, fill = ranking, text = importance) +
+ ggplot2::geom_tile() +
+ ggplot2::facet_wrap(~ method, scales = "free",
+ ncol = 2,
+ nrow = ceiling(length(unique(plt_df$method)) / 2)) +
+ ggplot2::coord_cartesian(expand = FALSE) +
+ ggplot2::labs(title = sprintf("%s = %s; %s = %s",
+ vary_param_name[1], val1,
+ vary_param_name[2], val2),
+ x = "Feature", y = "Replicate", fill = "Ranking") +
+ vthemes::scale_fill_vmodern() +
+ vthemes::theme_vmodern()
+ plt_ls[[as.character(val1)]][[as.character(val2)]] <- plt
+ }
+ }
+ }
+ return(plt_ls)
+}
+
+# view results in Rmarkdown
+# note: the calling code chunk must set results = "asis" in its chunk header
+view_results <- function(results_ls, metrics_plt_ls, rmetrics_plt_ls,
+ tpr_plt_ls, fi_bar_plt_ls, fi_box_plt_ls,
+ heatmap_plt_ls, vary_param_name_vec, abridged,
+ interactive = TRUE) {
+ cat(sprintf("\n\n# %s {.tabset .tabset-vmodern}\n\n",
+ basename(results_dir)))
+ if (is.null(vary_param_name_vec)) {
+ height <- 4
+ tpr_height <- height
+ } else {
+ height <- 8
+ tpr_height <- 4 * length(unique(results_ls[[paste0(vary_param_name_vec[2],
+ "_name")]]))
+ }
+
+ for (sim_name in names(results_ls)) {
+ vary_param_name <- stringr::str_remove(sim_name, "^varying\\_")
+ cat(sprintf("\n\n## %s {.tabset .tabset-pills}\n\n", sim_name))
+
+ if (!abridged) {
+ cat(sprintf("\n\n### Tables\n\n"))
+ vthemes::subchunkify(vthemes::pretty_DT(results_ls[[sim_name]]),
+ i = chunk_idx)
+ chunk_idx <<- chunk_idx + 1
+ }
+
+ cat(sprintf("\n\n### Plots {.tabset .tabset-pills .tabset-square}\n\n"))
+ if (interactive) {
+ vthemes::subchunkify(plotly::ggplotly(metrics_plt_ls[[sim_name]]),
+ i = chunk_idx, other_args = "out.width = '100%'",
+ fig_height = height,
+ add_class = "panel panel-default padded-panel")
+ chunk_idx <<- chunk_idx + 1
+ vthemes::subchunkify(plotly::ggplotly(rmetrics_plt_ls[[sim_name]]),
+ i = chunk_idx, other_args = "out.width = '100%'",
+ fig_height = height,
+ add_class = "panel panel-default padded-panel")
+ chunk_idx <<- chunk_idx + 1
+ vthemes::subchunkify(plotly::ggplotly(tpr_plt_ls[[sim_name]]),
+ i = chunk_idx, other_args = "out.width = '100%'",
+ fig_height = height,
+ add_class = "panel panel-default padded-panel")
+ } else {
+ vthemes::subchunkify(metrics_plt_ls[[sim_name]], i = chunk_idx,
+ fig_height = height)
+ chunk_idx <<- chunk_idx + 1
+ vthemes::subchunkify(rmetrics_plt_ls[[sim_name]], i = chunk_idx,
+ fig_height = height)
+ chunk_idx <<- chunk_idx + 1
+ vthemes::subchunkify(tpr_plt_ls[[sim_name]], i = chunk_idx,
+ fig_height = height)
+ }
+ chunk_idx <<- chunk_idx + 1
+
+ if (is.null(vary_param_name_vec)) {
+ for (param_val in names(heatmap_plt_ls[[sim_name]])) {
+ cat(sprintf("\n\n#### %s = %s\n\n\n", vary_param_name, param_val))
+ if (interactive) {
+ vthemes::subchunkify(plotly::ggplotly(heatmap_plt_ls[[sim_name]][[param_val]]),
+ i = chunk_idx, other_args = "out.width = '100%'",
+ add_class = "panel panel-default padded-panel")
+ chunk_idx <<- chunk_idx + 1
+ vthemes::subchunkify(plotly::ggplotly(fi_box_plt_ls[[sim_name]][[param_val]]),
+ i = chunk_idx, other_args = "out.width = '100%'",
+ add_class = "panel panel-default padded-panel")
+ if (!abridged) {
+ chunk_idx <<- chunk_idx + 1
+ vthemes::subchunkify(plotly::ggplotly(fi_bar_plt_ls[[sim_name]][[param_val]]),
+ i = chunk_idx, other_args = "out.width = '100%'",
+ add_class = "panel panel-default padded-panel")
+ }
+ } else {
+ vthemes::subchunkify(heatmap_plt_ls[[sim_name]][[param_val]], i = chunk_idx)
+ chunk_idx <<- chunk_idx + 1
+ vthemes::subchunkify(fi_box_plt_ls[[sim_name]][[param_val]], i = chunk_idx)
+ if (!abridged) {
+ chunk_idx <<- chunk_idx + 1
+ vthemes::subchunkify(fi_bar_plt_ls[[sim_name]][[param_val]], i = chunk_idx)
+ }
+ }
+ chunk_idx <<- chunk_idx + 1
+ }
+ } else {
+ for (param_val1 in names(heatmap_plt_ls[[sim_name]])) {
+ cat(sprintf("\n\n#### %s = %s {.tabset .tabset-pills .tabset-circle}\n\n\n",
+ vary_param_name_vec[1], param_val1))
+ for (param_val2 in names(heatmap_plt_ls[[sim_name]][[param_val1]])) {
+ cat(sprintf("\n\n##### %s = %s\n\n\n",
+ vary_param_name_vec[2], param_val2))
+ if (interactive) {
+ vthemes::subchunkify(plotly::ggplotly(heatmap_plt_ls[[sim_name]][[param_val1]][[param_val2]]),
+ i = chunk_idx, other_args = "out.width = '100%'",
+ add_class = "panel panel-default padded-panel")
+ chunk_idx <<- chunk_idx + 1
+ vthemes::subchunkify(plotly::ggplotly(fi_box_plt_ls[[sim_name]][[param_val1]][[param_val2]]),
+ i = chunk_idx, other_args = "out.width = '100%'",
+ add_class = "panel panel-default padded-panel")
+ if (!abridged) {
+ chunk_idx <<- chunk_idx + 1
+ vthemes::subchunkify(plotly::ggplotly(fi_bar_plt_ls[[sim_name]][[param_val1]][[param_val2]]),
+ i = chunk_idx, other_args = "out.width = '100%'",
+ add_class = "panel panel-default padded-panel")
+ }
+ } else {
+ vthemes::subchunkify(heatmap_plt_ls[[sim_name]][[param_val1]][[param_val2]], i = chunk_idx)
+ chunk_idx <<- chunk_idx + 1
+ vthemes::subchunkify(fi_box_plt_ls[[sim_name]][[param_val1]][[param_val2]], i = chunk_idx)
+ if (!abridged) {
+ chunk_idx <<- chunk_idx + 1
+ vthemes::subchunkify(fi_bar_plt_ls[[sim_name]][[param_val1]][[param_val2]], i = chunk_idx)
+ }
+ }
+ chunk_idx <<- chunk_idx + 1
+ }
+ }
+ }
+ }
+}
+```
+
+```{r}
+# read in results
+results_ls <- list()
+for (results_subdir in list.dirs(results_dir, full.names = T, recursive = F)) {
+ if (!is.null(vary_param_name)) {
+ if (!(results_subdir %in% file.path(results_dir,
+ paste0("varying_", vary_param_name)))) {
+ next
+ }
+ }
+ fname <- file.path(results_subdir, paste0("seed", seed), "results.csv")
+ if (file.exists(fname)) {
+ results_ls[[basename(results_subdir)]] <- data.table::fread(fname) %>%
+ reformat_results()
+ }
+}
+
+# plot evaluation metrics
+metrics_plt_ls <- list()
+for (sim_name in names(results_ls)) {
+ metrics <- intersect(colnames(results_ls[[sim_name]]), c("rocauc", "prauc"))
+ if (length(metrics) > 0) {
+ vary_param_name <- stringr::str_remove(sim_name, "^varying\\_")
+ metrics_plt_ls[[sim_name]] <- plot_metrics(
+      results_ls[[sim_name]], vary_param_name, vary_param_name_vec
+ )
+ } else {
+ metrics_plt_ls[[sim_name]] <- NULL
+ }
+}
+
+# plot restricted evaluation metrics
+rmetrics_plt_ls <- list()
+for (sim_name in names(results_ls)) {
+  metrics <- intersect(colnames(results_ls[[sim_name]]), c("rocauc", "prauc"))
+  if (length(metrics) > 0) {
+    vary_param_name <- stringr::str_remove(sim_name, "^varying\\_")
+    rmetrics_plt_ls[[sim_name]] <- plot_restricted_metrics(
+      results_ls[[sim_name]], vary_param_name, vary_param_name_vec
+ )
+ } else {
+ rmetrics_plt_ls[[sim_name]] <- NULL
+ }
+}
+
+# plot tpr
+tpr_plt_ls <- list()
+for (sim_name in names(results_ls)) {
+ vary_param_name <- stringr::str_remove(sim_name, "^varying\\_")
+ tpr_plt_ls[[sim_name]] <- plot_tpr(
+ results_ls[[sim_name]], vary_param_name, vary_param_name_vec
+ )
+}
+
+# plot feature importances
+fi_box_plt_ls <- list()
+fi_bar_plt_ls <- list()
+for (sim_name in names(results_ls)) {
+ vary_param_name <- stringr::str_remove(sim_name, "^varying\\_")
+ fi_box_plt_ls[[sim_name]] <- plot_feature_importance(
+ results_ls[[sim_name]], vary_param_name, vary_param_name_vec, keep_vars, plot_type = "boxplot"
+ )
+ fi_bar_plt_ls[[sim_name]] <- plot_feature_importance(
+ results_ls[[sim_name]], vary_param_name, vary_param_name_vec, keep_vars, plot_type = "bar"
+ )
+}
+
+# plot heatmap
+heatmap_plt_ls <- list()
+for (sim_name in names(results_ls)) {
+ vary_param_name <- stringr::str_remove(sim_name, "^varying\\_")
+ heatmap_plt_ls[[sim_name]] <- plot_ranking_heatmap(
+ results_ls[[sim_name]], vary_param_name, vary_param_name_vec, keep_vars
+ )
+}
+```
+
+```{r results = "asis"}
+# display plots nicely in knitted html document
+view_results(results_ls, metrics_plt_ls, rmetrics_plt_ls, tpr_plt_ls,
+ fi_bar_plt_ls, fi_box_plt_ls, heatmap_plt_ls,
+ vary_param_name_vec, abridged)
+```
+
diff --git a/feature_importance/savio/run_simulations_mdi_plus.sh b/feature_importance/savio/run_simulations_mdi_plus.sh
new file mode 100644
index 0000000..8bb29e6
--- /dev/null
+++ b/feature_importance/savio/run_simulations_mdi_plus.sh
@@ -0,0 +1,230 @@
+sims=(
+ # Regression: ccle_rnaseq
+ "mdi_plus.regression_sims.ccle_rnaseq_hier_poly_3m_2r_dgp"
+ "mdi_plus.regression_sims.ccle_rnaseq_linear_lss_3m_2r_dgp"
+ "mdi_plus.regression_sims.ccle_rnaseq_linear_dgp"
+ "mdi_plus.regression_sims.ccle_rnaseq_lss_3m_2r_dgp"
+ # Regression: enhancer
+ "mdi_plus.regression_sims.enhancer_hier_poly_3m_2r_dgp"
+ "mdi_plus.regression_sims.enhancer_linear_lss_3m_2r_dgp"
+ "mdi_plus.regression_sims.enhancer_linear_dgp"
+ "mdi_plus.regression_sims.enhancer_lss_3m_2r_dgp"
+ # Regression: juvenile
+ "mdi_plus.regression_sims.juvenile_hier_poly_3m_2r_dgp"
+ "mdi_plus.regression_sims.juvenile_linear_lss_3m_2r_dgp"
+ "mdi_plus.regression_sims.juvenile_linear_dgp"
+ "mdi_plus.regression_sims.juvenile_lss_3m_2r_dgp"
+ # Regression: splicing
+ "mdi_plus.regression_sims.splicing_hier_poly_3m_2r_dgp"
+ "mdi_plus.regression_sims.splicing_linear_lss_3m_2r_dgp"
+ "mdi_plus.regression_sims.splicing_linear_dgp"
+ "mdi_plus.regression_sims.splicing_lss_3m_2r_dgp"
+
+ # Classification: ccle_rnaseq
+ "mdi_plus.classification_sims.ccle_rnaseq_lss_3m_2r_logistic_dgp"
+ "mdi_plus.classification_sims.ccle_rnaseq_logistic_dgp"
+ "mdi_plus.classification_sims.ccle_rnaseq_linear_lss_3m_2r_logistic_dgp"
+ "mdi_plus.classification_sims.ccle_rnaseq_hier_poly_3m_2r_logistic_dgp"
+ # Classification: enhancer
+ "mdi_plus.classification_sims.enhancer_lss_3m_2r_logistic_dgp"
+ "mdi_plus.classification_sims.enhancer_logistic_dgp"
+ "mdi_plus.classification_sims.enhancer_linear_lss_3m_2r_logistic_dgp"
+ "mdi_plus.classification_sims.enhancer_hier_poly_3m_2r_logistic_dgp"
+ # Classification: juvenile
+ "mdi_plus.classification_sims.juvenile_lss_3m_2r_logistic_dgp"
+ "mdi_plus.classification_sims.juvenile_logistic_dgp"
+ "mdi_plus.classification_sims.juvenile_linear_lss_3m_2r_logistic_dgp"
+ "mdi_plus.classification_sims.juvenile_hier_poly_3m_2r_logistic_dgp"
+ # Classification: splicing
+ "mdi_plus.classification_sims.splicing_lss_3m_2r_logistic_dgp"
+ "mdi_plus.classification_sims.splicing_logistic_dgp"
+ "mdi_plus.classification_sims.splicing_linear_lss_3m_2r_logistic_dgp"
+ "mdi_plus.classification_sims.splicing_hier_poly_3m_2r_logistic_dgp"
+
+ # Robust: enhancer
+ "mdi_plus.robust_sims.enhancer_linear_10MS_robust_dgp"
+ "mdi_plus.robust_sims.enhancer_linear_25MS_robust_dgp"
+ "mdi_plus.robust_sims.enhancer_lss_3m_2r_10MS_robust_dgp"
+ "mdi_plus.robust_sims.enhancer_lss_3m_2r_25MS_robust_dgp"
+ "mdi_plus.robust_sims.enhancer_linear_lss_3m_2r_10MS_robust_dgp"
+ "mdi_plus.robust_sims.enhancer_linear_lss_3m_2r_25MS_robust_dgp"
+ "mdi_plus.robust_sims.enhancer_hier_poly_3m_2r_10MS_robust_dgp"
+ "mdi_plus.robust_sims.enhancer_hier_poly_3m_2r_25MS_robust_dgp"
+ # Robust: ccle_rnaseq
+ "mdi_plus.robust_sims.ccle_rnaseq_linear_10MS_robust_dgp"
+ "mdi_plus.robust_sims.ccle_rnaseq_linear_25MS_robust_dgp"
+ "mdi_plus.robust_sims.ccle_rnaseq_lss_3m_2r_10MS_robust_dgp"
+ "mdi_plus.robust_sims.ccle_rnaseq_lss_3m_2r_25MS_robust_dgp"
+ "mdi_plus.robust_sims.ccle_rnaseq_linear_lss_3m_2r_10MS_robust_dgp"
+ "mdi_plus.robust_sims.ccle_rnaseq_linear_lss_3m_2r_25MS_robust_dgp"
+ "mdi_plus.robust_sims.ccle_rnaseq_hier_poly_3m_2r_10MS_robust_dgp"
+ "mdi_plus.robust_sims.ccle_rnaseq_hier_poly_3m_2r_25MS_robust_dgp"
+
+ # MDI bias simulations
+ "mdi_plus.mdi_bias_sims.correlation_sims.normal_block_cor_partial_linear_lss_dgp"
+ "mdi_plus.mdi_bias_sims.entropy_sims.linear_dgp"
+ "mdi_plus.mdi_bias_sims.entropy_sims.logistic_dgp"
+
+ # Regression (varying number of features): ccle_rnaseq
+ "mdi_plus.other_regression_sims.varying_p.ccle_rnaseq_hier_poly_3m_2r_dgp"
+ "mdi_plus.other_regression_sims.varying_p.ccle_rnaseq_linear_lss_3m_2r_dgp"
+ "mdi_plus.other_regression_sims.varying_p.ccle_rnaseq_linear_dgp"
+ "mdi_plus.other_regression_sims.varying_p.ccle_rnaseq_lss_3m_2r_dgp"
+ # Regression (varying sparsity level): juvenile
+ "mdi_plus.other_regression_sims.varying_sparsity.juvenile_hier_poly_3m_2r_dgp"
+ "mdi_plus.other_regression_sims.varying_sparsity.juvenile_linear_lss_3m_2r_dgp"
+ "mdi_plus.other_regression_sims.varying_sparsity.juvenile_linear_dgp"
+ "mdi_plus.other_regression_sims.varying_sparsity.juvenile_lss_3m_2r_dgp"
+ # Regression (varying sparsity level): splicing
+ "mdi_plus.other_regression_sims.varying_sparsity.splicing_hier_poly_3m_2r_dgp"
+ "mdi_plus.other_regression_sims.varying_sparsity.splicing_linear_lss_3m_2r_dgp"
+ "mdi_plus.other_regression_sims.varying_sparsity.splicing_linear_dgp"
+ "mdi_plus.other_regression_sims.varying_sparsity.splicing_lss_3m_2r_dgp"
+
+ # MDI+ Modeling Choices: Enhancer (min_samples_per_leaf = 5)
+ "mdi_plus.other_regression_sims.modeling_choices_min_samples5.enhancer_hier_poly_3m_2r_dgp"
+ "mdi_plus.other_regression_sims.modeling_choices_min_samples5.enhancer_linear_dgp"
+ # MDI+ Modeling Choices: CCLE (min_samples_per_leaf = 5)
+ "mdi_plus.other_regression_sims.modeling_choices_min_samples5.ccle_rnaseq_hier_poly_3m_2r_dgp"
+ "mdi_plus.other_regression_sims.modeling_choices_min_samples5.ccle_rnaseq_linear_dgp"
+ # MDI+ Modeling Choices: Enhancer (min_samples_per_leaf = 1)
+ "mdi_plus.other_regression_sims.modeling_choices_min_samples1.enhancer_hier_poly_3m_2r_dgp"
+ "mdi_plus.other_regression_sims.modeling_choices_min_samples1.enhancer_linear_dgp"
+ # MDI+ Modeling Choices: CCLE (min_samples_per_leaf = 1)
+ "mdi_plus.other_regression_sims.modeling_choices_min_samples1.ccle_rnaseq_hier_poly_3m_2r_dgp"
+ "mdi_plus.other_regression_sims.modeling_choices_min_samples1.ccle_rnaseq_linear_dgp"
+
+ # MDI+ GLM and Metric Choices: Regression
+ "mdi_plus.glm_metric_choices_sims.regression_sims.ccle_rnaseq_hier_poly_3m_2r_dgp"
+ "mdi_plus.glm_metric_choices_sims.regression_sims.ccle_rnaseq_linear_dgp"
+ "mdi_plus.glm_metric_choices_sims.regression_sims.enhancer_hier_poly_3m_2r_dgp"
+ "mdi_plus.glm_metric_choices_sims.regression_sims.enhancer_linear_dgp"
+ # MDI+ GLM and Metric Choices: Classification
+ "mdi_plus.glm_metric_choices_sims.classification_sims.juvenile_logistic_dgp"
+ "mdi_plus.glm_metric_choices_sims.classification_sims.juvenile_hier_poly_3m_2r_logistic_dgp"
+ "mdi_plus.glm_metric_choices_sims.classification_sims.splicing_logistic_dgp"
+ "mdi_plus.glm_metric_choices_sims.classification_sims.splicing_hier_poly_3m_2r_logistic_dgp"
+)
+
+for sim in "${sims[@]}"
+do
+ sbatch --job-name=${sim} submit_simulation_job.sh ${sim}
+done
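+
+# Each entry in `sims` is a simulation config name that submit_simulation_job.sh
+# forwards to 01_run_importance_simulations.py via --config, so the loop above
+# submits one Slurm job per simulation setting, e.g.
+#   sbatch --job-name=mdi_plus.regression_sims.enhancer_linear_dgp \
+#     submit_simulation_job.sh mdi_plus.regression_sims.enhancer_linear_dgp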
+
+
+## Misspecified model simulations
+misspecifiedsims=(
+ # Misspecified Regression: ccle_rnaseq
+ "mdi_plus.regression_sims.ccle_rnaseq_hier_poly_3m_2r_dgp"
+ "mdi_plus.regression_sims.ccle_rnaseq_linear_lss_3m_2r_dgp"
+ "mdi_plus.regression_sims.ccle_rnaseq_linear_dgp"
+ "mdi_plus.regression_sims.ccle_rnaseq_lss_3m_2r_dgp"
+ # Misspecified Regression: enhancer
+ "mdi_plus.regression_sims.enhancer_hier_poly_3m_2r_dgp"
+ "mdi_plus.regression_sims.enhancer_linear_lss_3m_2r_dgp"
+ "mdi_plus.regression_sims.enhancer_linear_dgp"
+ "mdi_plus.regression_sims.enhancer_lss_3m_2r_dgp"
+ # Misspecified Regression: juvenile
+ "mdi_plus.regression_sims.juvenile_hier_poly_3m_2r_dgp"
+ "mdi_plus.regression_sims.juvenile_linear_lss_3m_2r_dgp"
+ "mdi_plus.regression_sims.juvenile_linear_dgp"
+ "mdi_plus.regression_sims.juvenile_lss_3m_2r_dgp"
+ # Misspecified Regression: splicing
+ "mdi_plus.regression_sims.splicing_hier_poly_3m_2r_dgp"
+ "mdi_plus.regression_sims.splicing_linear_lss_3m_2r_dgp"
+ "mdi_plus.regression_sims.splicing_linear_dgp"
+ "mdi_plus.regression_sims.splicing_lss_3m_2r_dgp"
+)
+
+for sim in "${misspecifiedsims[@]}"
+do
+ sbatch --job-name=${sim}_omitted_vars submit_simulation_job_omitted_vars.sh ${sim}
+done
+
+
+## Real-data Case Study
+# CCLE RNASeq
+drugs=(
+ "17-AAG"
+ "AEW541"
+ "AZD0530"
+ "AZD6244"
+ "Erlotinib"
+ "Irinotecan"
+ "L-685458"
+ "LBW242"
+ "Lapatinib"
+ "Nilotinib"
+ "Nutlin-3"
+ "PD-0325901"
+ "PD-0332991"
+ "PF2341066"
+ "PHA-665752"
+ "PLX4720"
+ "Paclitaxel"
+ "Panobinostat"
+ "RAF265"
+ "Sorafenib"
+ "TAE684"
+ "TKI258"
+ "Topotecan"
+ "ZD-6474"
+)
+
+sim="mdi_plus.real_data_case_study.ccle_rnaseq_regression-"
+for drug in "${drugs[@]}"
+do
+ sbatch --job-name=${sim}_${drug} submit_simulation_job_real_data_multitask.sh ${sim} "regression" ${drug}
+done
+
+sim="mdi_plus.real_data_case_study_no_data_split.ccle_rnaseq_regression-"
+for drug in "${drugs[@]}"
+do
+ sbatch --job-name=${sim}_${drug} submit_simulation_job_real_data_multitask.sh ${sim} "regression" ${drug}
+done
+
+# TCGA BRCA
+sim="mdi_plus.real_data_case_study.tcga_brca_classification-"
+sbatch --job-name=${sim} submit_simulation_job_real_data.sh ${sim} "multiclass_classification"
+
+sim="mdi_plus.real_data_case_study_no_data_split.tcga_brca_classification-"
+sbatch --job-name=${sim} submit_simulation_job_real_data.sh ${sim} "multiclass_classification"
+
+
+## Prediction Simulations
+
+# Real Data: Binary classification
+sims=(
+ "mdi_plus.prediction_sims.enhancer_classification-"
+ "mdi_plus.prediction_sims.juvenile_classification-"
+ "mdi_plus.prediction_sims.splicing_classification-"
+)
+
+for sim in "${sims[@]}"
+do
+ sbatch --job-name=${sim} submit_simulation_job_real_data_prediction.sh ${sim} "binary_classification"
+done
+
+# Real Data: Multi-class classification
+sim="mdi_plus.prediction_sims.tcga_brca_classification-"
+sbatch --job-name=${sim} submit_simulation_job_real_data_prediction.sh ${sim} "multiclass_classification"
+
+# Real Data: Regression
+sim="mdi_plus.prediction_sims.ccle_rnaseq_regression-"
+for drug in "${drugs[@]}"
+do
+ sbatch --job-name=${sim}_${drug} submit_simulation_job_real_data_prediction_multitask.sh ${sim} ${drug} "regression"
+done
+
+# MDI+ GLM and Metric Choices: Regression
+sims=(
+ "mdi_plus.glm_metric_choices_sims.regression_prediction_sims.ccle_rnaseq_hier_poly_3m_2r_dgp"
+ "mdi_plus.glm_metric_choices_sims.regression_prediction_sims.ccle_rnaseq_linear_dgp"
+ "mdi_plus.glm_metric_choices_sims.regression_prediction_sims.enhancer_hier_poly_3m_2r_dgp"
+ "mdi_plus.glm_metric_choices_sims.regression_prediction_sims.enhancer_linear_dgp"
+)
+
+for sim in "${sims[@]}"
+do
+ sbatch --job-name=${sim} submit_simulation_job_prediction.sh ${sim} "regression"
+done
\ No newline at end of file
diff --git a/feature_importance/savio/submit_simulation_job.sh b/feature_importance/savio/submit_simulation_job.sh
new file mode 100644
index 0000000..5421aed
--- /dev/null
+++ b/feature_importance/savio/submit_simulation_job.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+#SBATCH --account=co_stat
+#SBATCH --partition=savio
+#SBATCH --time=12:00:00
+#
+#SBATCH --nodes=1
+
+module load python/3.7
+module load r
+
+source activate r2f
+
+python ../01_run_importance_simulations.py --nreps 50 --config ${1} --split_seed 12345 --parallel --nosave_cols "prediction_model"
diff --git a/feature_importance/savio/submit_simulation_job_omitted_vars.sh b/feature_importance/savio/submit_simulation_job_omitted_vars.sh
new file mode 100644
index 0000000..737fd8d
--- /dev/null
+++ b/feature_importance/savio/submit_simulation_job_omitted_vars.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+#SBATCH --account=co_stat
+#SBATCH --partition=savio
+#SBATCH --time=12:00:00
+#
+#SBATCH --nodes=1
+
+module load python/3.7
+module load r
+
+source activate r2f
+
+python ../01_run_importance_simulations.py --nreps 50 --config ${1} --split_seed 12345 --omit_vars 0,1 --parallel --nosave_cols "prediction_model"
diff --git a/feature_importance/savio/submit_simulation_job_prediction.sh b/feature_importance/savio/submit_simulation_job_prediction.sh
new file mode 100644
index 0000000..f3c60ad
--- /dev/null
+++ b/feature_importance/savio/submit_simulation_job_prediction.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+#SBATCH --account=co_stat
+#SBATCH --partition=savio
+#SBATCH --time=12:00:00
+#
+#SBATCH --nodes=1
+
+module load python/3.7
+module load r
+
+source activate r2f
+
+python ../03_run_prediction_simulations.py --nreps 50 --config ${1} --split_seed 12345 --parallel --mode ${2} --nosave_cols "prediction_model"
diff --git a/feature_importance/savio/submit_simulation_job_real_data.sh b/feature_importance/savio/submit_simulation_job_real_data.sh
new file mode 100644
index 0000000..1fdd2b8
--- /dev/null
+++ b/feature_importance/savio/submit_simulation_job_real_data.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+#SBATCH --account=co_stat
+#SBATCH --partition=savio
+#SBATCH --time=48:00:00
+#
+#SBATCH --nodes=1
+
+module load python/3.7
+module load r
+
+source activate r2f
+
+python ../02_run_importance_real_data.py --nreps 32 --config ${1} --split_seed 12345 --parallel --nosave_cols "prediction_model"
diff --git a/feature_importance/savio/submit_simulation_job_real_data_multitask.sh b/feature_importance/savio/submit_simulation_job_real_data_multitask.sh
new file mode 100644
index 0000000..1a4ec9f
--- /dev/null
+++ b/feature_importance/savio/submit_simulation_job_real_data_multitask.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+#SBATCH --account=co_stat
+#SBATCH --partition=savio
+#SBATCH --time=48:00:00
+#
+#SBATCH --nodes=1
+
+module load python/3.7
+module load r
+
+source activate r2f
+
+python ../02_run_importance_real_data.py --nreps 32 --config ${1} --response_idx ${2} --split_seed 12345 --parallel --nosave_cols "prediction_model"
diff --git a/feature_importance/savio/submit_simulation_job_real_data_prediction.sh b/feature_importance/savio/submit_simulation_job_real_data_prediction.sh
new file mode 100644
index 0000000..b610ec5
--- /dev/null
+++ b/feature_importance/savio/submit_simulation_job_real_data_prediction.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+#SBATCH --account=co_stat
+#SBATCH --partition=savio
+#SBATCH --time=24:00:00
+#
+#SBATCH --nodes=1
+
+module load python/3.7
+module load r
+
+source activate r2f
+
+python ../04_run_prediction_real_data.py --nreps 32 --config ${1} --split_seed 12345 --parallel --mode ${2} --subsample_n 10000 --nosave_cols "prediction_model"
diff --git a/feature_importance/savio/submit_simulation_job_real_data_prediction_multitask.sh b/feature_importance/savio/submit_simulation_job_real_data_prediction_multitask.sh
new file mode 100644
index 0000000..5110468
--- /dev/null
+++ b/feature_importance/savio/submit_simulation_job_real_data_prediction_multitask.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+#SBATCH --account=co_stat
+#SBATCH --partition=savio
+#SBATCH --time=24:00:00
+#
+#SBATCH --nodes=1
+
+module load python/3.7
+module load r
+
+source activate r2f
+
+python ../04_run_prediction_real_data.py --nreps 32 --config ${1} --split_seed 12345 --parallel --response_idx ${2} --mode ${3} --nosave_cols "prediction_model"
diff --git a/feature_importance/scripts/__init__.py b/feature_importance/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/feature_importance/scripts/competing_methods.py b/feature_importance/scripts/competing_methods.py
new file mode 100644
index 0000000..7a96079
--- /dev/null
+++ b/feature_importance/scripts/competing_methods.py
@@ -0,0 +1,236 @@
+import os
+import sys
+import pandas as pd
+import numpy as np
+import sklearn.base
+from sklearn.base import RegressorMixin, ClassifierMixin
+from functools import reduce
+
+import shap
+from imodels.importance.rf_plus import RandomForestPlusRegressor, RandomForestPlusClassifier
+from feature_importance.scripts.mdi_oob import MDI_OOB
+from feature_importance.scripts.mda import MDA
+
+
+def tree_mdi_plus_ensemble(X, y, fit, scoring_fns="auto", **kwargs):
+ """
+ Wrapper around MDI+ object to get feature importance scores
+
+ :param X: ndarray of shape (n_samples, n_features)
+ The covariate matrix. If a pd.DataFrame object is supplied, then
+ the column names are used in the output
+ :param y: ndarray of shape (n_samples, n_targets)
+ The observed responses.
+    :param fit: fitted scikit-learn random forest object
+        The fitted RF model to be used for interpretation.
+    :param kwargs: dict mapping an RF+ configuration name to the additional
+        arguments passed to the RandomForestPlusRegressor or
+        RandomForestPlusClassifier class for that configuration.
+    :return: dataframe - [Var, Importance]
+        Var: variable name
+        Importance: median MDI+ rank of the feature across the fitted RF+ models
+ """
+
+ if isinstance(fit, RegressorMixin):
+ RFPlus = RandomForestPlusRegressor
+ elif isinstance(fit, ClassifierMixin):
+ RFPlus = RandomForestPlusClassifier
+ else:
+ raise ValueError("Unknown task.")
+
+ mdi_plus_scores_dict = {}
+ for rf_plus_name, rf_plus_args in kwargs.items():
+ rf_plus_model = RFPlus(rf_model=fit, **rf_plus_args)
+ rf_plus_model.fit(X, y)
+ try:
+ mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, scoring_fns=scoring_fns)
+ except ValueError as e:
+ if str(e) == 'Transformer representation was empty for all trees.':
+ mdi_plus_scores = pd.DataFrame(data=np.zeros(X.shape[1]), columns=['importance'])
+ if isinstance(X, pd.DataFrame):
+ mdi_plus_scores.index = X.columns
+ mdi_plus_scores.index.name = 'var'
+ mdi_plus_scores.reset_index(inplace=True)
+ else:
+ raise
+ for col in mdi_plus_scores.columns:
+ if col != "var":
+ mdi_plus_scores = mdi_plus_scores.rename(columns={col: col + "_" + rf_plus_name})
+ mdi_plus_scores_dict[rf_plus_name] = mdi_plus_scores
+
+ mdi_plus_scores_df = pd.concat([df.set_index('var') for df in mdi_plus_scores_dict.values()], axis=1)
+ mdi_plus_ranks_df = mdi_plus_scores_df.rank(ascending=False).median(axis=1)
+ mdi_plus_ranks_df = pd.DataFrame(mdi_plus_ranks_df, columns=["importance"]).reset_index()
+
+ return mdi_plus_ranks_df
+
+
+def tree_mdi_plus(X, y, fit, scoring_fns="auto", return_stability_scores=False, **kwargs):
+ """
+ Wrapper around MDI+ object to get feature importance scores
+
+ :param X: ndarray of shape (n_samples, n_features)
+ The covariate matrix. If a pd.DataFrame object is supplied, then
+ the column names are used in the output
+ :param y: ndarray of shape (n_samples, n_targets)
+ The observed responses.
+    :param fit: fitted scikit-learn random forest object
+        The fitted RF model to be used for interpretation.
+    :param return_stability_scores: boolean; whether to also compute and append
+        MDI+ stability scores (B = 25 bootstrap draws).
+ :param kwargs: additional arguments to pass to
+ RandomForestPlusRegressor or RandomForestPlusClassifier class.
+ :return: dataframe - [Var, Importance]
+ Var: variable name
+ Importance: MDI+ score
+ """
+
+ if isinstance(fit, RegressorMixin):
+ RFPlus = RandomForestPlusRegressor
+ elif isinstance(fit, ClassifierMixin):
+ RFPlus = RandomForestPlusClassifier
+ else:
+ raise ValueError("Unknown task.")
+ rf_plus_model = RFPlus(rf_model=fit, **kwargs)
+ rf_plus_model.fit(X, y)
+ try:
+ mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, scoring_fns=scoring_fns)
+ if return_stability_scores:
+ stability_scores = rf_plus_model.get_mdi_plus_stability_scores(B=25)
+ except ValueError as e:
+ if str(e) == 'Transformer representation was empty for all trees.':
+ mdi_plus_scores = pd.DataFrame(data=np.zeros(X.shape[1]), columns=['importance'])
+ if isinstance(X, pd.DataFrame):
+ mdi_plus_scores.index = X.columns
+ mdi_plus_scores.index.name = 'var'
+ mdi_plus_scores.reset_index(inplace=True)
+ stability_scores = None
+ else:
+ raise
+ mdi_plus_scores["prediction_score"] = rf_plus_model.prediction_score_
+ if return_stability_scores:
+ mdi_plus_scores = pd.concat([mdi_plus_scores, stability_scores], axis=1)
+
+ return mdi_plus_scores
+
+
+def tree_mdi(X, y, fit, include_num_splits=False):
+ """
+ Extract MDI values for a given tree
+ OR
+ Average MDI values for a given random forest
+ :param X: design matrix
+ :param y: response
+    :param fit: fitted model of interest
+    :param include_num_splits: boolean; whether to also report the average
+        number of splits per feature across trees (column 'av_splits')
+ :return: dataframe - [Var, Importance]
+ Var: variable name
+ Importance: MDI or avg MDI
+ """
+ av_splits = get_num_splits(X, y, fit)
+ results = fit.feature_importances_
+ results = pd.DataFrame(data=results, columns=['importance'])
+
+ # Use column names from dataframe if possible
+ if isinstance(X, pd.DataFrame):
+ results.index = X.columns
+ results.index.name = 'var'
+ results.reset_index(inplace=True)
+
+ if include_num_splits:
+ results['av_splits'] = av_splits
+
+ return results
+
+
+def tree_mdi_OOB(X, y, fit, type='oob',
+ normalized=False, balanced=False, demean=False, normal_fX=False):
+ """
+ Compute MDI-oob feature importance for a given random forest
+ :param X: design matrix
+ :param y: response
+ :param fit: fitted model of interest
+ :return: dataframe - [Var, Importance]
+ Var: variable name
+ Importance: MDI-oob
+ """
+ reshaped_y = y.reshape((len(y), 1))
+ results = MDI_OOB(fit, X, reshaped_y, type=type, normalized=normalized, balanced=balanced,
+ demean=demean, normal_fX=normal_fX)[0]
+ results = pd.DataFrame(data=results, columns=['importance'])
+ if isinstance(X, pd.DataFrame):
+ results.index = X.columns
+ results.index.name = 'var'
+ results.reset_index(inplace=True)
+
+ return results
+
+
+def tree_shap(X, y, fit):
+ """
+ Compute average treeshap value across observations
+ :param X: design matrix
+ :param y: response
+ :param fit: fitted model of interest (tree-based)
+ :return: dataframe - [Var, Importance]
+ Var: variable name
+ Importance: average absolute shap value
+ """
+ explainer = shap.TreeExplainer(fit)
+ shap_values = explainer.shap_values(X, check_additivity=False)
+ if sklearn.base.is_classifier(fit):
+ def add_abs(a, b):
+ return abs(a) + abs(b)
+ results = reduce(add_abs, shap_values)
+ else:
+ results = abs(shap_values)
+ results = results.mean(axis=0)
+ results = pd.DataFrame(data=results, columns=['importance'])
+ # Use column names from dataframe if possible
+ if isinstance(X, pd.DataFrame):
+ results.index = X.columns
+ results.index.name = 'var'
+ results.reset_index(inplace=True)
+
+ return results
+
+
+def tree_mda(X, y, fit, type="oob", n_repeats=10, metric="auto"):
+ """
+ Compute MDA importance for a given random forest
+ :param X: design matrix
+ :param y: response
+ :param fit: fitted model of interest
+ :param type: "oob" or "train"
+ :param n_repeats: number of permutations
+ :param metric: metric for computation MDA/permutation importance
+ :return: dataframe - [Var, Importance]
+ Var: variable name
+ Importance: MDA
+ """
+ if metric == "auto":
+ if isinstance(y[0], str):
+ metric = "accuracy"
+ else:
+ metric = "mse"
+
+ results, _ = MDA(fit, X, y[:, np.newaxis], type=type, n_trials=n_repeats, metric=metric)
+ results = pd.DataFrame(data=results, columns=['importance'])
+
+ if isinstance(X, pd.DataFrame):
+ results.index = X.columns
+ results.index.name = 'var'
+ results.reset_index(inplace=True)
+
+ return results
+
+
+def get_num_splits(X, y, fit):
+ """
+    Gets the average number of splits per feature across the trees of a fitted RF
+ """
+ num_splits_feature = np.zeros(X.shape[1])
+ for tree in fit.estimators_:
+ tree_features = tree.tree_.feature
+ for i in range(X.shape[1]):
+ num_splits_feature[i] += np.count_nonzero(tree_features == i)
+    num_splits_feature /= len(fit.estimators_)
+ return num_splits_feature
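+
+
+if __name__ == '__main__':
+    # Minimal smoke-test sketch (illustrative only, not part of the benchmark
+    # pipeline): check that the wrappers above return a ["var", "importance"]
+    # dataframe on toy data. Assumes scikit-learn (and shap, imported above)
+    # are installed in the environment.
+    from sklearn.ensemble import RandomForestRegressor
+
+    rng = np.random.RandomState(0)
+    X_demo = pd.DataFrame(rng.normal(size=(100, 5)),
+                          columns=["x%d" % i for i in range(5)])
+    y_demo = X_demo["x0"].values + 0.1 * rng.normal(size=100)
+    rf_demo = RandomForestRegressor(n_estimators=25, random_state=0)
+    rf_demo.fit(X_demo, y_demo)
+
+    print(tree_mdi(X_demo, y_demo, rf_demo))   # impurity-based MDI per feature
+    print(tree_shap(X_demo, y_demo, rf_demo))  # mean |SHAP| value per feature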
diff --git a/feature_importance/scripts/data_cleaning_mdi_plus.R b/feature_importance/scripts/data_cleaning_mdi_plus.R
new file mode 100644
index 0000000..a932cbf
--- /dev/null
+++ b/feature_importance/scripts/data_cleaning_mdi_plus.R
@@ -0,0 +1,65 @@
+#### Setup ####
+library(magrittr)
+if (!require("vdocs")) devtools::install_github("Yu-Group/vdocs")
+
+data_dir <- file.path("..", "data")
+
+#### Remove duplicate TFs in Enhancer ####
+load(file.path(data_dir, "enhancer.Rdata"))
+
+keep_vars <- varnames.all %>%
+ dplyr::group_by(Predictor_collapsed) %>%
+ dplyr::mutate(id = 1:dplyr::n()) %>%
+ dplyr::filter(id == 1)
+write.csv(
+ X[, keep_vars$Predictor],
+ file.path(data_dir, "X_enhancer.csv"),
+ row.names = FALSE
+)
+
+#### Clean covariate matrices ####
+X_paths <- c("X_juvenile.csv",
+ "X_splicing.csv",
+ "X_ccle_rnaseq.csv",
+ "X_enhancer.csv")
+log_transform <- c("X_splicing.csv",
+ "X_ccle_rnaseq.csv",
+ "X_enhancer.csv")
+
+for (X_path in X_paths) {
+ X_orig <- data.table::fread(file.path(data_dir, X_path)) %>%
+ tibble::as_tibble()
+
+ # dim(X_orig)
+ # sum(is.na(X_orig))
+
+ X <- X_orig %>%
+ vdocs::remove_constant_cols(verbose = 2) %>%
+ vdocs::remove_duplicate_cols(verbose = 2)
+
+ # hist(as.matrix(X))
+ if (X_path %in% log_transform) {
+ X <- log(X + 1)
+ # hist(as.matrix(X))
+ }
+
+ # dim(X)
+
+ write.csv(
+ X,
+ file.path(data_dir, "%s_cleaned.csv", fs::path_ext_remove(X_path)),
+ row.names = FALSE
+ )
+}
+
+#### Filter number of features for real data case study ####
+X_orig <- data.table::fread(file.path(data_dir, "X_ccle_rnaseq_cleaned.csv"))
+
+X <- X_orig %>%
+ vdocs::filter_cols_by_var(max_p = 5000)
+
+write.csv(
+ X,
+ file.path(data_dir, "X_ccle_rnaseq_cleaned_filtered5000.csv"),
+ row.names = FALSE
+)
diff --git a/feature_importance/scripts/mda.py b/feature_importance/scripts/mda.py
new file mode 100644
index 0000000..e0c59fa
--- /dev/null
+++ b/feature_importance/scripts/mda.py
@@ -0,0 +1,69 @@
+from sklearn.ensemble._forest import _generate_unsampled_indices, _generate_sample_indices
+from sklearn.metrics import accuracy_score, mean_squared_error
+import numpy as np
+import copy
+from sklearn.preprocessing import LabelEncoder
+
+
+def MDA(rf, X, y, type = 'oob', n_trials = 10, metric = 'accuracy'):
+ if len(y.shape) != 2:
+ raise ValueError('y must be 2d array (n_samples, 1) if numerical or (n_samples, n_categories).')
+
+ y_mda = copy.deepcopy(y)
+ if rf._estimator_type == "classifier" and y.dtype == "object":
+ y_mda = LabelEncoder().fit(y_mda.ravel()).transform(y_mda.ravel()).reshape(y_mda.shape[0], 1)
+
+ n_samples, n_features = X.shape
+ fi_mean = np.zeros((n_features,))
+ fi_std = np.zeros((n_features,))
+ best_score = rf_accuracy(rf, X, y_mda, type = type, metric = metric)
+ for f in range(n_features):
+ permute_score = 0
+ permute_std = 0
+ X_permute = X.copy()
+ for i in range(n_trials):
+ X_permute[:, f] = np.random.permutation(X_permute[:, f])
+ to_add = rf_accuracy(rf, X_permute, y_mda, type = type, metric = metric)
+ permute_score += to_add
+ permute_std += to_add ** 2
+ permute_score /= n_trials
+ permute_std /= n_trials
+ permute_std = (permute_std - permute_score ** 2) ** .5 / n_trials ** .5
+ fi_mean[f] = best_score - permute_score
+ fi_std[f] = permute_std
+ return fi_mean, fi_std
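+
+# For each feature f, fi_mean[f] = best_score - mean permuted score, i.e. the
+# average drop in the (OOB) metric after permuting column f over n_trials
+# permutations, and fi_std[f] is the Monte Carlo standard error of that mean
+# permuted score.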
+
+
+def neg_mse(y, y_hat):
+ return - mean_squared_error(y, y_hat)
+
+
+def rf_accuracy(rf, X, y, type = 'oob', metric = 'accuracy'):
+ if metric == 'accuracy':
+ score = accuracy_score
+ elif metric == 'mse':
+ score = neg_mse
+ else:
+ raise ValueError('metric type not understood')
+
+ n_samples, n_features = X.shape
+ tmp = 0
+ count = 0
+ if type == 'test':
+ return score(y, rf.predict(X))
+ elif type == 'train' and not rf.bootstrap:
+ return score(y, rf.predict(X))
+
+ for tree in rf.estimators_:
+ if type == 'oob':
+ if rf.bootstrap:
+ indices = _generate_unsampled_indices(tree.random_state, n_samples, n_samples)
+ else:
+ raise ValueError('Without bootstrap, it is not possible to calculate oob.')
+ elif type == 'train':
+ indices = _generate_sample_indices(tree.random_state, n_samples, n_samples)
+ else:
+ raise ValueError('type is not recognized. (%s)'%(type))
+ tmp += score(y[indices,:], tree.predict(X[indices, :])) * len(indices)
+ count += len(indices)
+ return tmp / count
\ No newline at end of file
diff --git a/feature_importance/scripts/mdi_oob.py b/feature_importance/scripts/mdi_oob.py
new file mode 100644
index 0000000..7c95228
--- /dev/null
+++ b/feature_importance/scripts/mdi_oob.py
@@ -0,0 +1,282 @@
+import numpy as np
+import warnings
+import sklearn
+
+from sklearn.ensemble._forest import ForestClassifier, ForestRegressor
+from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier, _tree
+from distutils.version import LooseVersion
+from sklearn.ensemble._forest import _generate_unsampled_indices, _generate_sample_indices
+from sklearn.metrics import accuracy_score, mean_squared_error
+from sklearn.preprocessing import scale
+
+
+def MDI_OOB(rf, X, y, type='oob', normalized=False, balanced=False, demean=False, normal_fX=False):
+ n_samples, n_features = X.shape
+ if len(y.shape) != 2:
+ raise ValueError('y must be 2d array (n_samples, 1) if numerical or (n_samples, n_categories).')
+ out = np.zeros((n_features,))
+ SE = np.zeros((n_features,))
+ if demean:
+ # demean y
+ y = y - np.mean(y, axis=0)
+
+ for tree in rf.estimators_:
+ if type == 'oob':
+ if rf.bootstrap:
+ indices = _generate_unsampled_indices(tree.random_state, n_samples, n_samples)
+ else:
+ raise ValueError('Without bootstrap, it is not possible to calculate oob.')
+ elif type == 'test':
+ indices = np.arange(n_samples)
+ elif type == 'classic':
+ if rf.bootstrap:
+ indices = _generate_sample_indices(tree.random_state, n_samples, n_samples)
+ else:
+ indices = np.arange(n_samples)
+ else:
+ raise ValueError('type is not recognized. (%s)'%(type))
+ _, _, contributions = _predict_tree(tree, X[indices, :])
+ if balanced and (type == 'oob' or type == 'test'):
+ base_indices = _generate_sample_indices(tree.random_state, n_samples, n_samples)
+ ids = tree.apply(X[indices, :])
+ base_ids = tree.apply(X[base_indices, :])
+ tmp1, tmp2 = np.unique(ids, return_counts=True)
+ weight1 = {key: 1. / value for key, value in zip(tmp1, tmp2)}
+ tmp1, tmp2 = np.unique(base_ids, return_counts=True)
+ weight2 = {key: value for key, value in zip(tmp1, tmp2)}
+ final_weights = np.array([[weight1[id] * weight2[id]] for id in ids])
+ final_weights /= np.mean(final_weights)
+ else:
+ final_weights = 1
+ if len(contributions.shape) == 2:
+ contributions = contributions[:, :, np.newaxis]
+ # print(contributions.shape, y[indices,:].shape)
+ if normal_fX:
+ for k in range(contributions.shape[-1]):
+ contributions[:, :, k] = scale(contributions[:, :, k])
+ if contributions.shape[2] == 2:
+ contributions = contributions[:, :, 1:]
+ elif contributions.shape[2] > 2:
+ raise ValueError('Multi-class y is not currently supported.')
+ tmp = np.tensordot(np.array(y[indices, :]) * final_weights, contributions, axes=([0, 1], [0, 2]))
+ if normalized:
+ # if sum(tmp) != 0:
+ # out += tmp / sum(tmp)
+ out += tmp / sum(tmp)
+ else:
+ out += tmp / len(indices)
+ if normalized:
+ # if sum(tmp) != 0:
+ # SE += (tmp / sum(tmp)) ** 2
+ SE += (tmp / sum(tmp)) ** 2
+ else:
+ SE += (tmp / len(indices)) ** 2
+ out /= rf.n_estimators
+ SE /= rf.n_estimators
+ SE = ((SE - out ** 2) / rf.n_estimators) ** .5
+ return out, SE
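+
+# Per tree, the importance of feature f is the inner product between the
+# (optionally weighted/demeaned) responses of the evaluation samples and that
+# tree's per-feature prediction contributions (see _predict_tree below), divided
+# by the number of evaluation samples (or normalized to sum to one); `out` is the
+# average of this quantity over trees and `SE` its standard error across trees.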
+
+
+def _get_tree_paths(tree, node_id, depth=0):
+ """
+ Returns all paths through the tree as list of node_ids
+ """
+ if node_id == _tree.TREE_LEAF:
+ raise ValueError("Invalid node_id %s" % _tree.TREE_LEAF)
+
+ left_child = tree.children_left[node_id]
+ right_child = tree.children_right[node_id]
+
+ if left_child != _tree.TREE_LEAF:
+ left_paths = _get_tree_paths(tree, left_child, depth=depth + 1)
+ right_paths = _get_tree_paths(tree, right_child, depth=depth + 1)
+
+ for path in left_paths:
+ path.append(node_id)
+ for path in right_paths:
+ path.append(node_id)
+ paths = left_paths + right_paths
+ else:
+ paths = [[node_id]]
+ return paths
+
+
+def _predict_tree(model, X, joint_contribution=False):
+ """
+ For a given DecisionTreeRegressor, DecisionTreeClassifier,
+ ExtraTreeRegressor, or ExtraTreeClassifier,
+ returns a triple of [prediction, bias and feature_contributions], such
+ that prediction ≈ bias + feature_contributions.
+ """
+ leaves = model.apply(X)
+ paths = _get_tree_paths(model.tree_, 0)
+
+ for path in paths:
+ path.reverse()
+
+ leaf_to_path = {}
+ # map leaves to paths
+ for path in paths:
+ leaf_to_path[path[-1]] = path
+
+ # remove the single-dimensional inner arrays
+ values = model.tree_.value.squeeze(axis=1)
+ # reshape if squeezed into a single float
+ if len(values.shape) == 0:
+ values = np.array([values])
+ if isinstance(model, DecisionTreeRegressor):
+ biases = np.full(X.shape[0], values[paths[0][0]])
+ line_shape = X.shape[1]
+ elif isinstance(model, DecisionTreeClassifier):
+ # scikit stores category counts, we turn them into probabilities
+ normalizer = values.sum(axis=1)[:, np.newaxis]
+ normalizer[normalizer == 0.0] = 1.0
+ values /= normalizer
+
+ biases = np.tile(values[paths[0][0]], (X.shape[0], 1))
+ line_shape = (X.shape[1], model.n_classes_)
+ else:
+        warnings.warn('Model type not recognized; attempting to proceed as a classifier, but this may fail.')
+ normalizer = values.sum(axis=1)[:, np.newaxis]
+ normalizer[normalizer == 0.0] = 1.0
+ values /= normalizer
+
+ biases = np.tile(values[paths[0][0]], (X.shape[0], 1))
+ line_shape = (X.shape[1], model.n_classes_)
+
+ direct_prediction = values[leaves]
+
+ # make into python list, accessing values will be faster
+ values_list = list(values)
+ feature_index = list(model.tree_.feature)
+
+ contributions = []
+ if joint_contribution:
+ for row, leaf in enumerate(leaves):
+ path = leaf_to_path[leaf]
+
+ path_features = set()
+ contributions.append({})
+ for i in range(len(path) - 1):
+ path_features.add(feature_index[path[i]])
+ contrib = values_list[path[i + 1]] - \
+ values_list[path[i]]
+ # path_features.sort()
+ contributions[row][tuple(sorted(path_features))] = \
+ contributions[row].get(tuple(sorted(path_features)), 0) + contrib
+ return direct_prediction, biases, contributions
+
+ else:
+ unique_leaves = np.unique(leaves)
+ unique_contributions = {}
+
+ for row, leaf in enumerate(unique_leaves):
+ for path in paths:
+ if leaf == path[-1]:
+ break
+
+ contribs = np.zeros(line_shape)
+ for i in range(len(path) - 1):
+ contrib = values_list[path[i + 1]] - \
+ values_list[path[i]]
+ contribs[feature_index[path[i]]] += contrib
+ unique_contributions[leaf] = contribs
+
+ for row, leaf in enumerate(leaves):
+ contributions.append(unique_contributions[leaf])
+
+ return direct_prediction, biases, np.array(contributions)
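+
+# In the decomposition above, the bias is the value at the root node and each
+# split along a sample's decision path adds value[child] - value[parent] to the
+# contribution of the feature that was split on, so summing a row of
+# `contributions` recovers prediction - bias for that sample.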
+
+
+def _predict_forest(model, X, joint_contribution=False):
+ """
+ For a given RandomForestRegressor, RandomForestClassifier,
+ ExtraTreesRegressor, or ExtraTreesClassifier returns a triple of
+ [prediction, bias and feature_contributions], such that prediction ≈ bias +
+ feature_contributions.
+ """
+ biases = []
+ contributions = []
+ predictions = []
+
+ if joint_contribution:
+
+ for tree in model.estimators_:
+ pred, bias, contribution = _predict_tree(tree, X, joint_contribution=joint_contribution)
+
+ biases.append(bias)
+ contributions.append(contribution)
+ predictions.append(pred)
+
+ total_contributions = []
+
+ for i in range(len(X)):
+ contr = {}
+ for j, dct in enumerate(contributions):
+ for k in set(dct[i]).union(set(contr.keys())):
+ contr[k] = (contr.get(k, 0) * j + dct[i].get(k, 0)) / (j + 1)
+
+ total_contributions.append(contr)
+
+
+ return (np.mean(predictions, axis=0), np.mean(biases, axis=0),
+ total_contributions)
+ else:
+ for tree in model.estimators_:
+ pred, bias, contribution = _predict_tree(tree, X)
+
+ biases.append(bias)
+ contributions.append(contribution)
+ predictions.append(pred)
+
+ return (np.mean(predictions, axis=0), np.mean(biases, axis=0),
+ np.mean(contributions, axis=0))
+
+
+def predict(model, X, joint_contribution=False):
+ """ Returns a triple (prediction, bias, feature_contributions), such
+ that prediction ≈ bias + feature_contributions.
+ Parameters
+ ----------
+ model : DecisionTreeRegressor, DecisionTreeClassifier,
+ ExtraTreeRegressor, ExtraTreeClassifier,
+ RandomForestRegressor, RandomForestClassifier,
+ ExtraTreesRegressor, ExtraTreesClassifier
+ Scikit-learn model on which the prediction should be decomposed.
+ X : array-like, shape = (n_samples, n_features)
+ Test samples.
+
+ joint_contribution : boolean
+ Specifies if contributions are given individually from each feature,
+ or jointly over them
+ Returns
+ -------
+ decomposed prediction : triple of
+ * prediction, shape = (n_samples) for regression and (n_samples, n_classes)
+ for classification
+ * bias, shape = (n_samples) for regression and (n_samples, n_classes) for
+ classification
+ * contributions, If joint_contribution is False then returns and array of
+ shape = (n_samples, n_features) for regression or
+ shape = (n_samples, n_features, n_classes) for classification, denoting
+ contribution from each feature.
+ If joint_contribution is True, then shape is array of size n_samples,
+ where each array element is a dict from a tuple of feature indices to
+ to a value denoting the contribution from that feature tuple.
+ """
+    # Only a single output response variable is supported.
+ if model.n_outputs_ > 1:
+ raise ValueError("Multilabel classification trees not supported")
+
+ if (isinstance(model, DecisionTreeClassifier) or
+ isinstance(model, DecisionTreeRegressor)):
+ return _predict_tree(model, X, joint_contribution=joint_contribution)
+ elif (isinstance(model, ForestClassifier) or
+ isinstance(model, ForestRegressor)):
+ return _predict_forest(model, X, joint_contribution=joint_contribution)
+ else:
+ raise ValueError("Wrong model type. Base learner needs to be a "
+ "DecisionTreeClassifier or DecisionTreeRegressor.")
diff --git a/feature_importance/scripts/simulations_util.py b/feature_importance/scripts/simulations_util.py
new file mode 100644
index 0000000..fab8027
--- /dev/null
+++ b/feature_importance/scripts/simulations_util.py
@@ -0,0 +1,1019 @@
+import numpy as np
+import pandas as pd
+import random
+from scipy.linalg import toeplitz
+import warnings
+import math
+
+
+def sample_real_X(fpath=None, X=None, seed=None, normalize=True,
+ sample_row_n=None, sample_col_n=None, permute_col=True,
+ signal_features=None, n_signal_features=None, permute_nonsignal_col=None):
+ """
+ :param fpath: path to X data
+ :param X: data matrix
+ :param seed: random seed
+ :param normalize: boolean; whether or not to normalize columns in data to mean 0 and variance 1
+ :param sample_row_n: number of samples to subset; default keeps all rows
+ :param sample_col_n: number of features to subset; default keeps all columns
+ :param permute_col: boolean; whether or not to permute the columns
+ :param signal_features: list of features to use as signal features
+ :param n_signal_features: number of signal features; required if permute_nonsignal_col is not None
+ :param permute_nonsignal_col: how to permute the nonsignal features; must be one of
+ [None, "block", "indep", "augment"], where None performs no permutation, "block" performs the permutation
+ row-wise, "indep" permutes each nonsignal feature column independently, "augment" augments the signal features
+ with the row-permuted X matrix.
+    :return: sampled covariate matrix as a numpy array (an IndexedArray when permute_nonsignal_col == "augment")
+ """
+ assert permute_nonsignal_col in [None, "block", "indep", "augment"]
+ if X is None:
+ X = pd.read_csv(fpath)
+ if normalize:
+ X = (X - X.mean()) / X.std()
+ if seed is not None:
+ np.random.seed(seed)
+ if permute_col:
+ X = X[np.random.permutation(X.columns)]
+ if sample_row_n is not None:
+ keep_idx = np.random.choice(X.shape[0], sample_row_n, replace=False)
+ X = X.iloc[keep_idx, :]
+ if sample_col_n is not None:
+ if signal_features is None:
+ X = X.sample(n=sample_col_n, replace=False, axis=1)
+ else:
+ rand_features = np.random.choice([col for col in X.columns if col not in signal_features],
+ sample_col_n - len(signal_features), replace=False)
+ X = X[signal_features + list(rand_features)]
+ if signal_features is not None:
+ X = X[signal_features + [col for col in X.columns if col not in signal_features]]
+ if permute_nonsignal_col is not None:
+ assert n_signal_features is not None
+ if permute_nonsignal_col == "block":
+ X = np.hstack([X.iloc[:, :n_signal_features].to_numpy(),
+ X.iloc[np.random.permutation(X.shape[0]), n_signal_features:].to_numpy()])
+ X = pd.DataFrame(X)
+ elif permute_nonsignal_col == "indep":
+ for j in range(n_signal_features, X.shape[1]):
+ X.iloc[:, j] = np.random.permutation(X.iloc[:, j])
+ elif permute_nonsignal_col == "augment":
+ X = np.hstack([X.iloc[:, :n_signal_features].to_numpy(),
+ X.iloc[np.random.permutation(X.shape[0]), :].to_numpy()])
+ X = IndexedArray(pd.DataFrame(X).to_numpy(), index=keep_idx)
+ return X
+ return X.to_numpy()
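+
+# Illustrative call (the file path and gene names below are placeholders, not
+# assets shipped with this function):
+#   X = sample_real_X(fpath="../data/X_ccle_rnaseq_cleaned.csv", seed=0,
+#                     sample_row_n=100, sample_col_n=50,
+#                     signal_features=["GENE_A", "GENE_B"])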
+
+
+def sample_normal_X(n, d, mean=0, scale=1, corr=0, Sigma=None):
+ """
+    Sample X with normal entries (iid when corr = 0 and Sigma is None)
+    :param n: number of samples
+    :param d: number of features
+    :param mean: mean of the normal distribution (scalar or length-d vector)
+    :param scale: standard deviation of the normal entries (iid case only)
+    :param corr: equicorrelation between features (ignored if Sigma is given)
+    :param Sigma: optional full covariance matrix
+    :return: ndarray of shape (n, d)
+ """
+ if Sigma is not None:
+ if np.isscalar(mean):
+ mean = np.repeat(mean, d)
+ X = np.random.multivariate_normal(mean, Sigma, size=n)
+ elif corr == 0:
+ X = np.random.normal(mean, scale, size=(n, d))
+ else:
+ Sigma = np.zeros((d, d)) + corr
+ np.fill_diagonal(Sigma, 1)
+ if np.isscalar(mean):
+ mean = np.repeat(mean, d)
+ X = np.random.multivariate_normal(mean, Sigma, size=n)
+ return X
+
+
+def sample_block_cor_X(n, d, rho, n_blocks, mean=0):
+ """
+ Sample X from N(mean, Sigma) where Sigma is a block diagonal covariance matrix
+ :param n: number of samples
+ :param d: number of features
+ :param rho: correlation or vector of correlations
+ :param n_blocks: number of blocks
+ :param mean: mean of normal distribution
+    :return: ndarray of shape (n, d)
+ """
+ Sigma = np.zeros((d, d))
+ block_size = d // n_blocks
+ if np.isscalar(rho):
+ rho = np.repeat(rho, n_blocks)
+ for i in range(n_blocks):
+ start = i * block_size
+ end = (i + 1) * block_size
+ if i == (n_blocks - 1):
+ end = d
+ Sigma[start:end, start:end] = rho[i]
+ np.fill_diagonal(Sigma, 1)
+ X = sample_normal_X(n=n, d=d, mean=mean, Sigma=Sigma)
+ return X
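+
+# Example of the block-diagonal covariance constructed above: with d = 4,
+# n_blocks = 2, and rho = 0.8, Sigma is
+#   [[1.0, 0.8, 0.0, 0.0],
+#    [0.8, 1.0, 0.0, 0.0],
+#    [0.0, 0.0, 1.0, 0.8],
+#    [0.0, 0.0, 0.8, 1.0]]
+# i.e. features are equicorrelated at rho within a block and independent across blocks.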
+
+
+def sample_X(support, X_fun, **kwargs):
+ """
+ Wrapper around dgp function for X that reorders columns so support features are in front
+ :param support: list of column indices for the signal features; extended in place with the remaining column indices
+ :param X_fun: function used to generate X
+ :param kwargs: keyword arguments passed to X_fun
+ :return: X with the support columns moved to the front
+ """
+ X = X_fun(**kwargs)
+ for i in range(X.shape[1]):
+ if i not in support:
+ support.append(i)
+ X[:] = X[:, support]
+ return X
+
+
+def generate_coef(beta, s):
+ if isinstance(beta, int) or isinstance(beta, float):
+ beta = np.repeat(beta, repeats=s)
+ return beta
+
+
+def corrupt_leverage(x_train, y_train, mean_shift, corrupt_quantile, mode="normal"):
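+ """
+ Replace the responses of the highest- and lowest-leverage rows, where leverage is measured
+ by the Euclidean norm of each row of x_train. The top corrupt_quantile fraction of rows is
+ set to mean_shift (mode="constant") or to draws from N(mean_shift, 1) (mode="normal"), and
+ the bottom fraction to the negated values. Returns y_train unchanged if mean_shift is None.
+ """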
+ assert mode in ["normal", "constant"]
+ if mean_shift is None:
+ return y_train
+ ranked_rows = np.apply_along_axis(np.linalg.norm, axis=1, arr=x_train).argsort().argsort()
+ low_idx = np.where(ranked_rows < round(corrupt_quantile * len(y_train)))[0]
+ hi_idx = np.where(ranked_rows >= (len(y_train) - round(corrupt_quantile * len(y_train))))[0]
+ if mode == "normal":
+ hi_corrupted = np.random.normal(mean_shift, 1, size=len(hi_idx))
+ low_corrupted = np.random.normal(-mean_shift, 1, size=len(low_idx))
+ elif mode == "constant":
+ hi_corrupted = mean_shift
+ low_corrupted = -mean_shift
+ y_train[hi_idx] = hi_corrupted
+ y_train[low_idx] = low_corrupted
+ return y_train
+
+
+def linear_model(X, sigma, s, beta, heritability=None, snr=None, error_fun=None,
+ frac_corrupt=None, corrupt_how='permute', corrupt_size=None,
+ corrupt_mean=None, return_support=False):
+ """
+ This method is used to create responses from a linear model with hard sparsity
+ Parameters:
+ X: X matrix
+ s: sparsity
+ beta: coefficient vector. If beta not a vector, then assumed a constant
+ sigma: s.d. of added noise
+ Returns:
+ numpy array of shape (n)
+ """
+ n, p = X.shape
+ def create_y(x, s, beta):
+ linear_term = 0
+ for j in range(s):
+ linear_term += x[j] * beta[j]
+ return linear_term
+
+ beta = generate_coef(beta, s)
+ y_train = np.array([create_y(X[i, :], s, beta) for i in range(len(X))])
+ if heritability is not None:
+ sigma = (np.var(y_train) * ((1.0 - heritability) / heritability)) ** 0.5
+ if snr is not None:
+ sigma = (np.var(y_train) / snr) ** 0.5
+ if error_fun is None:
+ error_fun = np.random.randn
+ if frac_corrupt is None and corrupt_size is None:
+ y_train = y_train + sigma * error_fun(n)
+ else:
+ if frac_corrupt is None:
+ frac_corrupt = 0
+ num_corrupt = int(np.floor(frac_corrupt*len(y_train)))
+ corrupt_indices = random.sample([*range(len(y_train))], k=num_corrupt)
+ if corrupt_how == 'permute':
+ corrupt_array = y_train[corrupt_indices]
+ corrupt_array = random.sample(list(corrupt_array), len(corrupt_array))
+ for i,index in enumerate(corrupt_indices):
+ y_train[index] = corrupt_array[i]
+ y_train = y_train + sigma * error_fun(n)
+ elif corrupt_how == 'cauchy':
+ for i in range(len(y_train)):
+ if i in corrupt_indices:
+ y_train[i] = y_train[i] + sigma*np.random.standard_cauchy()
+ else:
+ y_train[i] = y_train[i] + sigma*error_fun()
+ elif corrupt_how == "leverage_constant":
+ if isinstance(corrupt_size, int):
+ corrupt_quantile = corrupt_size / n
+ else:
+ corrupt_quantile = corrupt_size
+ y_train = y_train + sigma * error_fun(n)
+ corrupt_idx = np.random.choice(range(s, p), size=1)
+ y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="constant")
+ elif corrupt_how == "leverage_normal":
+ if isinstance(corrupt_size, int):
+ corrupt_quantile = corrupt_size / n
+ else:
+ corrupt_quantile = corrupt_size
+ y_train = y_train + sigma * error_fun(n)
+ corrupt_idx = np.random.choice(range(s, p), size=1)
+ y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="normal")
+
+ if return_support:
+ support = np.concatenate((np.ones(s), np.zeros(X.shape[1] - s)))
+ return y_train, support, beta
+ else:
+ return y_train
+
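+# Illustrative usage sketch (not part of the original pipeline; n, d, s, and the target
+# heritability are arbitrary choices for demonstration):
+#   X = sample_normal_X(n=500, d=20)
+#   y, support, beta = linear_model(X, sigma=None, s=5, beta=1, heritability=0.4,
+#                                   return_support=True)
+#   # support has 20 entries; the first 5 are 1 (signal features), the rest are 0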
+
+def lss_model(X, sigma, m, r, tau, beta, heritability=None, snr=None, error_fun=None, min_active=None,
+ frac_corrupt=None, corrupt_how='permute', corrupt_size=None, corrupt_mean=None,
+ return_support=False):
+ """
+ This method creates responses from an LSS model
+
+ X: data matrix
+ m: number of interaction terms
+ r: max order of interaction
+ tau: threshold
+ sigma: standard deviation of noise
+ beta: coefficient vector. If beta not a vector, then assumed a constant
+
+ :return
+ y_train: numpy array of shape (n)
+ """
+ n, p = X.shape
+ assert p >= m * r # Cannot have more interactions * size than the dimension
+
+ def lss_func(x, beta):
+ x_bool = (x - tau) > 0
+ y = 0
+ for j in range(m):
+ lss_term_components = x_bool[j * r:j * r + r]
+ lss_term = int(all(lss_term_components))
+ y += lss_term * beta[j]
+ return y
+
+ def lss_vector_fun(x, beta):
+ x_bool = (x - tau) > 0
+ y = 0
+ max_iter = 100
+ features = np.arange(p)
+ support_idx = []
+ for j in range(m):
+ cnt = 0
+ while True:
+ int_features = np.random.choice(features, size=r, replace=False)
+ lss_term_components = x_bool[:, int_features]
+ lss_term = np.apply_along_axis(all, 1, lss_term_components)
+ cnt += 1
+ if np.mean(lss_term) >= min_active or cnt > max_iter:
+ y += lss_term * beta[j]
+ features = list(set(features).difference(set(int_features)))
+ support_idx.append(int_features)
+ if cnt > max_iter:
+ warnings.warn("Could not find interaction {} with min active >= {}".format(j, min_active))
+ break
+ support_idx = np.stack(support_idx).ravel()
+ support = np.zeros(p)
+ for j in support_idx:
+ support[j] = 1
+ return y, support
+
+ beta = generate_coef(beta, m)
+ if tau == 'median':
+ tau = np.median(X, axis=0)
+
+ if min_active is None:
+ y_train = np.array([lss_func(X[i, :], beta) for i in range(n)])
+ support = np.concatenate((np.ones(m * r), np.zeros(X.shape[1] - (m * r))))
+ else:
+ y_train, support = lss_vector_fun(X, beta)
+
+ if heritability is not None:
+ sigma = (np.var(y_train) * ((1.0 - heritability) / heritability)) ** 0.5
+ if snr is not None:
+ sigma = (np.var(y_train) / snr) ** 0.5
+ if error_fun is None:
+ error_fun = np.random.randn
+
+ if frac_corrupt is None and corrupt_size is None:
+ y_train = y_train + sigma * error_fun(n)
+ else:
+ if frac_corrupt is None:
+ frac_corrupt = 0
+ num_corrupt = int(np.floor(frac_corrupt*len(y_train)))
+ corrupt_indices = random.sample([*range(len(y_train))], k=num_corrupt)
+ if corrupt_how == 'permute':
+ corrupt_array = y_train[corrupt_indices]
+ corrupt_array = random.sample(list(corrupt_array), len(corrupt_array))
+ for i,index in enumerate(corrupt_indices):
+ y_train[index] = corrupt_array[i]
+ y_train = y_train + sigma * error_fun(n)
+ elif corrupt_how == 'cauchy':
+ for i in range(len(y_train)):
+ if i in corrupt_indices:
+ y_train[i] = y_train[i] + sigma*np.random.standard_cauchy()
+ else:
+ y_train[i] = y_train[i] + sigma*error_fun()
+ elif corrupt_how == "leverage_constant":
+ if isinstance(corrupt_size, int):
+ corrupt_quantile = corrupt_size / n
+ else:
+ corrupt_quantile = corrupt_size
+ y_train = y_train + sigma * error_fun(n)
+ corrupt_idx = np.random.choice(range(m*r, p), size=1)
+ y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="constant")
+ elif corrupt_how == "leverage_normal":
+ if isinstance(corrupt_size, int):
+ corrupt_quantile = corrupt_size / n
+ else:
+ corrupt_quantile = corrupt_size
+ y_train = y_train + sigma * error_fun(n)
+ corrupt_idx = np.random.choice(range(m*r, p), size=1)
+ y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="normal")
+
+ if return_support:
+ return y_train, support, beta
+ else:
+ return y_train
+
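+# Illustrative usage sketch (not part of the original pipeline; parameter values are arbitrary):
+#   X = sample_normal_X(n=500, d=30)
+#   y, support, beta = lss_model(X, sigma=None, m=3, r=2, tau=0, beta=1,
+#                                heritability=0.4, return_support=True)
+#   # each of the m=3 interaction terms fires only when both of its r=2 features exceed tau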
+
+def partial_linear_lss_model(X, sigma, s, m, r, tau, beta, heritability=None, snr=None, error_fun=None,
+ min_active=None, frac_corrupt=None, corrupt_how='permute', corrupt_size=None,
+ corrupt_mean=None, diagnostics=False, return_support=False):
+ """
+ This method creates responses from a linear + LSS model
+
+ X: data matrix
+ m: number of interaction terms
+ r: max order of interaction
+ s: denotes number of linear terms in EACH interaction term
+ tau: threshold
+ sigma: standard deviation of noise
+ beta: coefficient vector. If beta not a vector, then assumed a constant
+
+ :return
+ y_train: numpy array of shape (n)
+ """
+ n, p = X.shape
+ assert p >= m * r # Cannot have more interactions * size than the dimension
+ assert s <= r
+
+ def partial_linear_func(x,s,beta):
+ y = 0.0
+ count = 0
+ for j in range(m):
+ for i in range(s):
+ y += beta[count]*x[j*r+i]
+ count += 1
+ return y
+
+
+ def lss_func(x, beta):
+ x_bool = (x - tau) > 0
+ y = 0
+ for j in range(m):
+ lss_term_components = x_bool[j * r:j * r + r]
+ lss_term = int(all(lss_term_components))
+ y += lss_term * beta[j]
+ return y
+
+ def lss_vector_fun(x, beta, beta_linear):
+ x_bool = (x - tau) > 0
+ y = 0
+ max_iter = 100
+ features = np.arange(p)
+ support_idx = []
+ for j in range(m):
+ cnt = 0
+ while True:
+ int_features = np.concatenate(
+ [np.arange(j*r, j*r+s), np.random.choice(features, size=r-s, replace=False)]
+ )
+ lss_term_components = x_bool[:, int_features]
+ lss_term = np.apply_along_axis(all, 1, lss_term_components)
+ cnt += 1
+ if np.mean(lss_term) >= min_active or cnt > max_iter:
+ norm_constant = sum(np.var(x[:, (j*r):(j*r+s)], axis=0) * beta_linear[(j*s):((j+1)*s)]**2)
+ relative_beta = beta[j] / sum(beta_linear[(j*s):((j+1)*s)])
+ y += lss_term * relative_beta * np.sqrt(norm_constant) / np.std(lss_term)
+ features = list(set(features).difference(set(int_features)))
+ support_idx.append(int_features)
+ if cnt > max_iter:
+ warnings.warn("Could not find interaction {} with min active >= {}".format(j, min_active))
+ break
+ support_idx = np.stack(support_idx).ravel()
+ support = np.zeros(p)
+ for j in support_idx:
+ support[j] = 1
+ return y, support
+
+ beta_lss = generate_coef(beta, m)
+ beta_linear = generate_coef(beta, s*m)
+ if tau == 'median':
+ tau = np.median(X, axis=0)
+
+ y_train_linear = np.array([partial_linear_func(X[i, :],s,beta_linear ) for i in range(n)])
+ if min_active is None:
+ y_train_lss = np.array([lss_func(X[i, :], beta_lss) for i in range(n)])
+ support = np.concatenate((np.ones(max(m * r, s)), np.zeros(X.shape[1] - max((m * r), s))))
+ else:
+ y_train_lss, support = lss_vector_fun(X, beta_lss, beta_linear)
+ y_train = np.array([y_train_linear[i] + y_train_lss[i] for i in range(n)])
+ if heritability is not None:
+ sigma = (np.var(y_train) * ((1.0 - heritability) / heritability)) ** 0.5
+ if snr is not None:
+ sigma = (np.var(y_train) / snr) ** 0.5
+ if error_fun is None:
+ error_fun = np.random.randn
+
+ if frac_corrupt is None and corrupt_size is None:
+ y_train = y_train + sigma * error_fun(n)
+ else:
+ if frac_corrupt is None:
+ frac_corrupt = 0
+ num_corrupt = int(np.floor(frac_corrupt*len(y_train)))
+ corrupt_indices = random.sample([*range(len(y_train))], k=num_corrupt)
+ if corrupt_how == 'permute':
+ corrupt_array = y_train[corrupt_indices]
+ corrupt_array = random.sample(list(corrupt_array), len(corrupt_array))
+ for i,index in enumerate(corrupt_indices):
+ y_train[index] = corrupt_array[i]
+ y_train = y_train + sigma * error_fun(n)
+ elif corrupt_how == 'cauchy':
+ for i in range(len(y_train)):
+ if i in corrupt_indices:
+ y_train[i] = y_train[i] + sigma*np.random.standard_cauchy()
+ else:
+ y_train[i] = y_train[i] + sigma*error_fun()
+ elif corrupt_how == "leverage_constant":
+ if isinstance(corrupt_size, int):
+ corrupt_quantile = corrupt_size / n
+ else:
+ corrupt_quantile = corrupt_size
+ y_train = y_train + sigma * error_fun(n)
+ corrupt_idx = np.random.choice(range(max(m*r, s), p), size=1)
+ y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="constant")
+ elif corrupt_how == "leverage_normal":
+ if isinstance(corrupt_size, int):
+ corrupt_quantile = corrupt_size / n
+ else:
+ corrupt_quantile = corrupt_size
+ y_train = y_train + sigma * error_fun(n)
+ corrupt_idx = np.random.choice(range(max(m*r, s), p), size=1)
+ y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="normal")
+
+ if return_support:
+ return y_train, support, beta_lss
+ elif diagnostics:
+ return y_train, y_train_linear, y_train_lss
+ else:
+ return y_train
+
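+# Illustrative usage sketch (not part of the original pipeline; parameter values are arbitrary):
+#   X = sample_normal_X(n=500, d=30)
+#   y, support, beta = partial_linear_lss_model(X, sigma=None, s=1, m=3, r=2, tau=0, beta=1,
+#                                               heritability=0.4, return_support=True)
+#   # each interaction block contributes s=1 linear term plus one order-r thresholded LSS term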
+
+def hierarchical_poly(X, sigma=None, m=1, r=1, beta=1, heritability=None, snr=None,
+ frac_corrupt=None, corrupt_how='permute', corrupt_size=None,
+ corrupt_mean=None, error_fun=None, return_support=False):
+ """
+ This method creates responses from a hierarchical polynomial model, where term i
+ contributes beta[i] * prod_j (1 + x[i * r + j]) over its r features
+
+ X: data matrix
+ m: number of interaction terms
+ r: max order of interaction
+ sigma: standard deviation of noise
+ beta: coefficient vector. If beta not a vector, then assumed a constant
+
+ :return
+ y_train: numpy array of shape (n)
+ """
+
+ n, p = X.shape
+ assert p >= m * r
+
+ def reg_func(x, beta):
+ y = 0
+ for i in range(m):
+ hier_term = 1.0
+ for j in range(r):
+ hier_term += x[i * r + j] * hier_term
+ y += hier_term * beta[i]
+ return y
+
+ beta = generate_coef(beta, m)
+ y_train = np.array([reg_func(X[i, :], beta) for i in range(n)])
+ if heritability is not None:
+ sigma = (np.var(y_train) * ((1.0 - heritability) / heritability)) ** 0.5
+ if snr is not None:
+ sigma = (np.var(y_train) / snr) ** 0.5
+ if error_fun is None:
+ error_fun = np.random.randn
+
+ if frac_corrupt is None and corrupt_size is None:
+ y_train = y_train + sigma * error_fun(n)
+ else:
+ if frac_corrupt is None:
+ frac_corrupt = 0
+ num_corrupt = int(np.floor(frac_corrupt*len(y_train)))
+ corrupt_indices = random.sample([*range(len(y_train))], k=num_corrupt)
+ if corrupt_how == 'permute':
+ corrupt_array = y_train[corrupt_indices]
+ corrupt_array = random.sample(list(corrupt_array), len(corrupt_array))
+ for i,index in enumerate(corrupt_indices):
+ y_train[index] = corrupt_array[i]
+ y_train = y_train + sigma * error_fun(n)
+ elif corrupt_how == 'cauchy':
+ for i in range(len(y_train)):
+ if i in corrupt_indices:
+ y_train[i] = y_train[i] + sigma*np.random.standard_cauchy()
+ else:
+ y_train[i] = y_train[i] + sigma*error_fun()
+ elif corrupt_how == "leverage_constant":
+ if isinstance(corrupt_size, int):
+ corrupt_quantile = corrupt_size / n
+ else:
+ corrupt_quantile = corrupt_size
+ y_train = y_train + sigma * error_fun(n)
+ corrupt_idx = np.random.choice(range(m*r, p), size=1)
+ y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="constant")
+ elif corrupt_how == "leverage_normal":
+ if isinstance(corrupt_size, int):
+ corrupt_quantile = corrupt_size / n
+ else:
+ corrupt_quantile = corrupt_size
+ y_train = y_train + sigma * error_fun(n)
+ corrupt_idx = np.random.choice(range(m*r, p), size=1)
+ y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="normal")
+
+ if return_support:
+ support = np.concatenate((np.ones(m * r), np.zeros(X.shape[1] - (m * r))))
+ return y_train, support, beta
+ else:
+ return y_train
+
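+# Illustrative usage sketch (not part of the original pipeline; parameter values are arbitrary):
+#   X = sample_normal_X(n=500, d=20)
+#   y = hierarchical_poly(X, m=2, r=3, beta=1, heritability=0.4)
+#   # two additive terms, each a full interaction hierarchy over its own 3 features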
+
+def logistic_model(X, s, beta=None, beta_grid=np.logspace(-4, 4, 100), heritability=None,
+ frac_label_corruption=None, return_support=False):
+ """
+ This method is used to create responses from a logistic model with hard sparsity
+ Parameters:
+ X: X matrix
+ s: sparsity
+ beta: coefficient vector. If beta not a vector, then assumed a constant
+ Returns:
+ numpy array of shape (n)
+ """
+
+ def create_prob(x, beta):
+ linear_term = 0
+ for j in range(len(beta)):
+ linear_term += x[j] * beta[j]
+ prob = 1 / (1 + np.exp(-linear_term))
+ return prob
+
+ def create_y(x, beta):
+ linear_term = 0
+ for j in range(len(beta)):
+ linear_term += x[j] * beta[j]
+ prob = 1 / (1 + np.exp(-linear_term))
+ return (np.random.uniform(size=1) < prob) * 1
+
+ if heritability is None:
+ beta = generate_coef(beta, s)
+ y_train = np.array([create_y(X[i, :], beta) for i in range(len(X))]).ravel()
+ else:
+ # find beta to get desired heritability via adaptive grid search within eps=0.01
+ y_train, beta, heritability, hdict = logistic_heritability_search(X, heritability, s, create_prob, beta_grid)
+
+ if frac_label_corruption is not None:
+ corrupt_indices = np.random.choice(np.arange(len(y_train)), size=math.ceil(frac_label_corruption * len(y_train)))
+ y_train[corrupt_indices] = 1 - y_train[corrupt_indices]
+ if return_support:
+ support = np.concatenate((np.ones(s), np.zeros(X.shape[1] - s)))
+ return y_train, support, beta
+ else:
+ return y_train
+
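+# Illustrative usage sketch (not part of the original pipeline; parameter values are arbitrary):
+#   X = sample_normal_X(n=500, d=20)
+#   y, support, beta = logistic_model(X, s=5, beta=1, return_support=True)
+#   # y[i] ~ Bernoulli(sigmoid(X[i, :5] @ beta)), so only the first 5 features carry signal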
+
+def logistic_lss_model(X, m, r, tau, beta=None, heritability=None, beta_grid=np.logspace(-4, 4, 100),
+ min_active=None, frac_label_corruption=None, return_support=False):
+ """
+ This method is used to create responses from a logistic LSS model
+ X: X matrix
+ m: number of interaction terms
+ r: max order of interaction
+ tau: threshold
+ beta: coefficient vector. If beta not a vector, then assumed a constant
+ Returns:
+ numpy array of shape (n)
+ """
+ n, p = X.shape
+
+ def lss_prob_func(x, beta):
+ x_bool = (x - tau) > 0
+ y = 0
+ for j in range(m):
+ lss_term_components = x_bool[j * r:j * r + r]
+ lss_term = int(all(lss_term_components))
+ y += lss_term * beta[j]
+ prob = 1 / (1 + np.exp(-y))
+ return prob
+
+ def lss_func(x, beta):
+ x_bool = (x - tau) > 0
+ y = 0
+ for j in range(m):
+ lss_term_components = x_bool[j * r:j * r + r]
+ lss_term = int(all(lss_term_components))
+ y += lss_term * beta[j]
+ prob = 1 / (1 + np.exp(-y))
+ return (np.random.uniform(size=1) < prob) * 1
+
+ def lss_vector_fun(x, beta):
+ x_bool = (x - tau) > 0
+ y = 0
+ max_iter = 100
+ features = np.arange(p)
+ support_idx = []
+ for j in range(m):
+ cnt = 0
+ while True:
+ int_features = np.random.choice(features, size=r, replace=False)
+ lss_term_components = x_bool[:, int_features]
+ lss_term = np.apply_along_axis(all, 1, lss_term_components)
+ cnt += 1
+ if np.mean(lss_term) >= min_active or cnt > max_iter:
+ y += lss_term * beta[j]
+ features = list(set(features).difference(set(int_features)))
+ support_idx.append(int_features)
+ if cnt > max_iter:
+ warnings.warn("Could not find interaction {} with min active >= {}".format(j, min_active))
+ break
+ prob = 1 / (1 + np.exp(-y))
+ y = (np.random.uniform(size=n) < prob) * 1
+ support_idx = np.stack(support_idx).ravel()
+ support = np.zeros(p)
+ for j in support_idx:
+ support[j] = 1
+ return y, support
+
+ if tau == 'median':
+ tau = np.median(X, axis=0)
+
+ if heritability is None:
+ beta = generate_coef(beta, m)
+ if min_active is None:
+ y_train = np.array([lss_func(X[i, :], beta) for i in range(n)]).ravel()
+ support = np.concatenate((np.ones(m * r), np.zeros(X.shape[1] - (m * r))))
+ else:
+ y_train, support = lss_vector_fun(X, beta)
+ y_train = y_train.ravel()
+ else:
+ if min_active is not None:
+ raise ValueError("Cannot set heritability and min_active at the same time.")
+ # find beta to get desired heritability via adaptive grid search within eps=0.01 (need to jitter beta to reach higher signals)
+ y_train, beta, heritability, hdict = logistic_heritability_search(X, heritability, m, lss_prob_func, beta_grid, jitter_beta=True)
+ support = np.concatenate((np.ones(m * r), np.zeros(X.shape[1] - (m * r))))
+
+ if frac_label_corruption is not None:
+ corrupt_indices = np.random.choice(np.arange(len(y_train)), size=math.ceil(frac_label_corruption * len(y_train)))
+ y_train[corrupt_indices] = 1 - y_train[corrupt_indices]
+
+ if return_support:
+ return y_train, support, beta
+ else:
+ return y_train
+
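+# Illustrative usage sketch (not part of the original pipeline; parameter values are arbitrary):
+#   X = sample_normal_X(n=500, d=30)
+#   y = logistic_lss_model(X, m=3, r=2, tau=0, beta=2)
+#   # P(y = 1 | x) = sigmoid(2 * number of interaction blocks whose r features all exceed tau)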
+
+def logistic_partial_linear_lss_model(X, s, m, r, tau, beta=None, heritability=None, beta_grid=np.logspace(-4, 4, 100),
+ min_active=None, frac_label_corruption=None, return_support=False):
+ """
+ This method is used to create responses from a logistic partially linear LSS model
+ X: X matrix
+ s: number of linear terms in each interaction block
+ m: number of interaction terms
+ r: max order of interaction
+ tau: threshold
+ beta: coefficient vector. If beta not a vector, then assumed a constant
+ Returns:
+ numpy array of shape (n)
+ """
+ n, p = X.shape
+ assert p >= m * r
+
+ def partial_linear_func(x,s,beta):
+ y = 0.0
+ count = 0
+ for j in range(m):
+ for i in range(s):
+ y += beta[count]*x[j*r+i]
+ count += 1
+ return y
+
+ def lss_func(x, beta):
+ x_bool = (x - tau) > 0
+ y = 0
+ for j in range(m):
+ lss_term_components = x_bool[j * r:j * r + r]
+ lss_term = int(all(lss_term_components))
+ y += lss_term * beta[j]
+ return y
+
+ def logistic_link_func(y):
+ prob = 1 / (1 + np.exp(-y))
+ return (np.random.uniform(size=1) < prob) * 1
+
+ def logistic_prob_func(y):
+ prob = 1 / (1 + np.exp(-y))
+ return prob
+
+ def lss_vector_fun(x, beta, beta_linear):
+ x_bool = (x - tau) > 0
+ y = 0
+ max_iter = 100
+ features = np.arange(p)
+ support_idx = []
+ for j in range(m):
+ cnt = 0
+ while True:
+ int_features = np.concatenate(
+ [np.arange(j*r, j*r+s), np.random.choice(features, size=r-s, replace=False)]
+ )
+ lss_term_components = x_bool[:, int_features]
+ lss_term = np.apply_along_axis(all, 1, lss_term_components)
+ cnt += 1
+ if np.mean(lss_term) >= min_active or cnt > max_iter:
+ norm_constant = sum(np.var(x[:, (j*r):(j*r+s)], axis=0) * beta_linear[(j*s):((j+1)*s)]**2)
+ relative_beta = beta[j] / sum(beta_linear[(j*s):((j+1)*s)])
+ y += lss_term * relative_beta * np.sqrt(norm_constant) / np.std(lss_term)
+ features = list(set(features).difference(set(int_features)))
+ support_idx.append(int_features)
+ if cnt > max_iter:
+ warnings.warn("Could not find interaction {} with min active >= {}".format(j, min_active))
+ break
+ support_idx = np.stack(support_idx).ravel()
+ support = np.zeros(p)
+ for j in support_idx:
+ support[j] = 1
+ return y, support
+
+ if tau == 'median':
+ tau = np.median(X, axis=0)
+
+ if heritability is None:
+ beta_lss = generate_coef(beta, m)
+ beta_linear = generate_coef(beta, s*m)
+
+ y_train_linear = np.array([partial_linear_func(X[i, :],s,beta_linear ) for i in range(n)])
+ if min_active is None:
+ y_train_lss = np.array([lss_func(X[i, :], beta_lss) for i in range(n)])
+ support = np.concatenate((np.ones(m * r), np.zeros(X.shape[1] - (m * r))))
+ else:
+ y_train_lss, support = lss_vector_fun(X, beta_lss, beta_linear)
+ y_train = np.array([y_train_linear[i] + y_train_lss[i] for i in range(n)])
+ y_train = np.array([logistic_link_func(y_train[i]) for i in range(n)])
+ else:
+ if min_active is not None:
+ raise ValueError("Cannot set heritability and min_active at the same time.")
+ # find beta to get desired heritability via adaptive grid search within eps=0.01
+ eps = 0.01
+ max_iter = 1000
+ pves = {}
+ for idx, beta in enumerate(beta_grid):
+ beta_lss_vec = generate_coef(beta, m)
+ beta_linear_vec = generate_coef(beta, s*m)
+
+ y_train_linear = np.array([partial_linear_func(X[i, :], s, beta_linear_vec) for i in range(n)])
+ y_train_lss = np.array([lss_func(X[i, :], beta_lss_vec) for i in range(n)])
+ y_train_sum = np.array([y_train_linear[i] + y_train_lss[i] for i in range(n)])
+ prob_train = np.array([logistic_prob_func(y_train_sum[i]) for i in range(n)]).ravel()
+ np.random.seed(idx)
+ y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1
+ pve = np.var(prob_train) / np.var(y_train)
+ pves[(idx, beta)] = pve
+
+ (idx, beta), pve = min(pves.items(), key=lambda x: abs(x[1] - heritability))
+ beta_lss_vec = generate_coef(beta, m)
+ beta_linear_vec = generate_coef(beta, s*m)
+
+ y_train_linear = np.array([partial_linear_func(X[i, :], s, beta_linear_vec) for i in range(n)])
+ y_train_lss = np.array([lss_func(X[i, :], beta_lss_vec) for i in range(n)])
+ y_train_sum = np.array([y_train_linear[i] + y_train_lss[i] for i in range(n)])
+
+ prob_train = np.array([logistic_prob_func(y_train_sum[i]) for i in range(n)]).ravel()
+ np.random.seed(idx)
+ y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1
+ if pve > heritability:
+ min_beta = beta_grid[idx-1]
+ max_beta = beta
+ else:
+ min_beta = beta
+ max_beta = beta_grid[idx+1]
+ cur_beta = (min_beta + max_beta) / 2
+ iter = 1
+ while np.abs(pve - heritability) > eps:
+ beta_lss_vec = generate_coef(cur_beta, m)
+ beta_linear_vec = generate_coef(cur_beta, s*m)
+
+ y_train_linear = np.array([partial_linear_func(X[i, :], s, beta_linear_vec) for i in range(n)])
+ y_train_lss = np.array([lss_func(X[i, :], beta_lss_vec) for i in range(n)])
+ y_train_sum = np.array([y_train_linear[i] + y_train_lss[i] for i in range(n)])
+
+ prob_train = np.array([logistic_prob_func(y_train_sum[i]) for i in range(n)]).ravel()
+ np.random.seed(iter + len(beta_grid))
+ y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1
+ pve = np.var(prob_train) / np.var(y_train)
+ pves[(iter + len(beta_grid), cur_beta)] = pve
+ if pve > heritability:
+ max_beta = cur_beta
+ else:
+ min_beta = cur_beta
+ beta = cur_beta
+ cur_beta = (min_beta + max_beta) / 2
+ iter += 1
+ if iter > max_iter:
+ (idx, cur_beta), pve = min(pves.items(), key=lambda x: abs(x[1] - heritability))
+ beta_lss_vec = generate_coef(cur_beta, m)
+ beta_linear_vec = generate_coef(cur_beta, s*m)
+
+ y_train_linear = np.array([partial_linear_func(X[i, :], s, beta_linear_vec) for i in range(n)])
+ y_train_lss = np.array([lss_func(X[i, :], beta_lss_vec) for i in range(n)])
+ y_train_sum = np.array([y_train_linear[i] + y_train_lss[i] for i in range(n)])
+
+ prob_train = np.array([logistic_prob_func(y_train_sum[i]) for i in range(n)]).ravel()
+ np.random.seed(idx)
+ y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1
+ pve = np.var(prob_train) / np.var(y_train)
+ beta = cur_beta
+ break
+ support = np.concatenate((np.ones(m * r), np.zeros(X.shape[1] - (m * r))))
+
+ if frac_label_corruption is not None:
+ corrupt_indices = np.random.choice(np.arange(len(y_train)), size=math.ceil(frac_label_corruption * len(y_train)))
+ y_train[corrupt_indices] = 1 - y_train[corrupt_indices]
+
+ y_train = y_train.ravel()
+
+ if return_support:
+ return y_train, support, beta
+ else:
+ return y_train
+
+
+def logistic_hier_model(X, m, r, beta=None, heritability=None, beta_grid=np.logspace(-4, 4, 100),
+ frac_label_corruption=None, return_support=False):
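+ """
+ This method is used to create responses from a logistic hierarchical polynomial model
+ X: X matrix
+ m: number of interaction terms
+ r: max order of interaction
+ beta: coefficient vector. If beta not a vector, then assumed a constant
+ Returns:
+ numpy array of shape (n)
+ """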
+
+ n, p = X.shape
+ assert p >= m * r
+
+ def reg_func(x, beta):
+ y = 0
+ for i in range(m):
+ hier_term = 1.0
+ for j in range(r):
+ hier_term += x[i * r + j] * hier_term
+ y += hier_term * beta[i]
+ return y
+
+ def logistic_link_func(y):
+ prob = 1 / (1 + np.exp(-y))
+ return (np.random.uniform(size=1) < prob) * 1
+
+ def prob_func(x, beta):
+ y = 0
+ for i in range(m):
+ hier_term = 1.0
+ for j in range(r):
+ hier_term += x[i * r + j] * hier_term
+ y += hier_term * beta[i]
+ return 1 / (1 + np.exp(-y))
+
+ if heritability is None:
+ beta = generate_coef(beta, m)
+ y_train = np.array([reg_func(X[i, :], beta) for i in range(n)])
+ y_train = np.array([logistic_link_func(y_train[i]) for i in range(n)])
+ else:
+ # find beta to get desired heritability via adaptive grid search within eps=0.01
+ y_train, beta, heritability, hdict = logistic_heritability_search(X, heritability, m, prob_func, beta_grid)
+
+ if frac_label_corruption is not None:
+ corrupt_indices = np.random.choice(np.arange(len(y_train)), size=math.ceil(frac_label_corruption * len(y_train)))
+ y_train[corrupt_indices] = 1 - y_train[corrupt_indices]
+ y_train = y_train.ravel()
+
+ if return_support:
+ support = np.concatenate((np.ones(m * r), np.zeros(X.shape[1] - (m * r))))
+ return y_train, support, beta
+ else:
+ return y_train
+
+
+def logistic_heritability_search(X, heritability, s, prob_fun, beta_grid=np.logspace(-4, 4, 100),
+ eps=0.01, max_iter=1000, jitter_beta=False, return_pve=True):
+ pves = {}
+
+ # first search over beta grid
+ for idx, beta in enumerate(beta_grid):
+ np.random.seed(idx)
+ beta_vec = generate_coef(beta, s)
+ if jitter_beta:
+ beta_vec = beta_vec + np.random.uniform(-1e-4, 1e-4, beta_vec.shape)
+ prob_train = np.array([prob_fun(X[i, :], beta_vec) for i in range(len(X))]).ravel()
+ y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1
+ pve = np.var(prob_train) / np.var(y_train)
+ pves[(idx, beta)] = pve
+
+ # find beta with heritability closest to desired heritability
+ (idx, beta), pve = min(pves.items(), key=lambda x: abs(x[1] - heritability))
+ np.random.seed(idx)
+ beta_vec = generate_coef(beta, s)
+ if jitter_beta:
+ beta_vec = beta_vec + np.random.uniform(-1e-4, 1e-4, beta_vec.shape)
+ prob_train = np.array([prob_fun(X[i, :], beta_vec) for i in range(len(X))]).ravel()
+ y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1
+
+ # search nearby beta to get closer to desired heritability
+ if pve > heritability:
+ min_beta = beta_grid[idx-1]
+ max_beta = beta
+ else:
+ min_beta = beta
+ max_beta = beta_grid[idx+1]
+ cur_beta = (min_beta + max_beta) / 2
+ iter = 1
+ while np.abs(pve - heritability) > eps:
+ np.random.seed(iter + len(beta_grid))
+ beta_vec = generate_coef(cur_beta, s)
+ if jitter_beta:
+ beta_vec = beta_vec + np.random.uniform(-1e-4, 1e-4, beta_vec.shape)
+ prob_train = np.array([prob_fun(X[i, :], beta_vec) for i in range(len(X))]).ravel()
+ y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1
+ pve = np.var(prob_train) / np.var(y_train)
+ pves[(iter + len(beta_grid), cur_beta)] = pve
+ if pve > heritability:
+ max_beta = cur_beta
+ else:
+ min_beta = cur_beta
+ cur_beta = (min_beta + max_beta) / 2
+ beta = beta_vec
+ iter += 1
+ if iter > max_iter:
+ (idx, cur_beta), pve = min(pves.items(), key=lambda x: abs(x[1] - heritability))
+ np.random.seed(idx)
+ beta_vec = generate_coef(cur_beta, s)
+ if jitter_beta:
+ beta_vec = beta_vec + np.random.uniform(-1e-4, 1e-4, beta_vec.shape)
+ prob_train = np.array([prob_fun(X[i, :], beta_vec) for i in range(len(X))]).ravel()
+ y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1
+ pve = np.var(prob_train) / np.var(y_train)
+ beta = beta_vec
+ break
+
+ if return_pve:
+ return y_train, beta, pve, pves
+ else:
+ return y_train, beta
+
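+# Illustrative note (not from the original source): the "heritability" targeted above is the
+# proportion of variance explained on the probability scale,
+#   PVE = Var(p_i) / Var(y_i),  where p_i = P(y_i = 1 | x_i),
+# and the search stops once |PVE - heritability| <= eps or max_iter iterations are reached.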
+
+def entropy_X(n, scale=False):
+ x1 = np.random.choice([0, 1], (n, 1), replace=True)
+ x2 = np.random.normal(0, 1, (n, 1))
+ x3 = np.random.choice(np.arange(4), (n, 1), replace=True)
+ x4 = np.random.choice(np.arange(10), (n, 1), replace=True)
+ x5 = np.random.choice(np.arange(20), (n, 1), replace=True)
+ X = np.concatenate((x1, x2, x3, x4, x5), axis=1)
+ if scale:
+ X = (X - X.mean()) / X.std()
+ return X
+
+
+def entropy_y(X, c=3, return_support=False):
+ if any(X[:, 0] < 0):
+ x = (X[:, 0] > 0) * 1
+ else:
+ x = X[:, 0]
+ prob = ((c - 2) * x + 1) / c
+ y = (np.random.uniform(size=len(prob)) < prob) * 1
+ if return_support:
+ support = np.array([0, 1, 0, 0, 0])
+ beta = None
+ return y, support, beta
+ else:
+ return y
+
+
+class IndexedArray(np.ndarray):
+ def __new__(cls, input_array, index=None):
+ obj = np.asarray(input_array).view(cls)
+ obj.index = index
+ return obj
+
+ def __array_finalize__(self, obj):
+ if obj is None:
+ return
+ self.index = getattr(obj, 'index', None)
+
+#%%
diff --git a/feature_importance/scripts/viz.R b/feature_importance/scripts/viz.R
new file mode 100644
index 0000000..c61cee0
--- /dev/null
+++ b/feature_importance/scripts/viz.R
@@ -0,0 +1,781 @@
+library(magrittr)
+
+
+# reformat results
+reformat_results <- function(results, prediction = FALSE) {
+ if (!prediction) {
+ results_grouped <- results %>%
+ dplyr::group_by(index) %>%
+ tidyr::nest(fi_scores = var:(tidyselect::last_col())) %>%
+ dplyr::ungroup() %>%
+ dplyr::select(-index) %>%
+ # join fi+model to get method column
+ tidyr::unite(col = "method", fi, model, na.rm = TRUE, remove = FALSE) %>%
+ dplyr::mutate(
+ # get rid of duplicate RF in r2f method name
+ method = ifelse(stringr::str_detect(method, "^r2f.*RF$"),
+ stringr::str_remove(method, "\\_RF$"), method),
+ # unnest columns
+ prediction_score = purrr::map_dbl(
+ fi_scores,
+ function(x) {
+ ifelse("prediction_score" %in% colnames(x), x$prediction_score[[1]], NA)
+ }
+ ),
+ tauAP = purrr::map_dbl(
+ fi_scores,
+ function(x) {
+ ifelse("tauAP" %in% colnames(x), x$tauAP[[1]], NA)
+ }
+ ),
+ RBO = purrr::map_dbl(
+ fi_scores,
+ function(x) {
+ ifelse("RBO" %in% colnames(x), x$RBO[[1]], NA)
+ }
+ )
+ )
+ } else {
+ results_grouped <- results %>%
+ dplyr::group_by(index) %>%
+ tidyr::nest(predictions = sample_id:(tidyselect::last_col())) %>%
+ dplyr::ungroup() %>%
+ dplyr::select(-index) %>%
+ dplyr::mutate(
+ # get rid of duplicate RF in r2f method name
+ method = ifelse(stringr::str_detect(model, "^r2f.*RF$"),
+ stringr::str_remove(model, "\\_RF$"), model)
+ )
+ }
+ return(results_grouped)
+}
+
+# plot metrics (mean value across repetitions with error bars)
+plot_metrics <- function(results,
+ metric = c("rocauc", "prauc"),# "tpr",
+ #"median_signal_rank", "max_signal_rank"),
+ x_str, facet_str, linetype_str = NULL,
+ point_size = 1, line_size = 1, errbar_width = 0,
+ alpha = 0.5, inside_legend = FALSE,
+ manual_color_palette = NULL,
+ show_methods = NULL,
+ method_labels = ggplot2::waiver(),
+ alpha_values = NULL,
+ legend_position = NULL,
+ custom_theme = vthemes::theme_vmodern(size_preset = "medium")) {
+ if (is.null(show_methods)) {
+ show_methods <- sort(unique(results$method))
+ }
+ metric_names <- metric
+ plt_df <- results %>%
+ dplyr::select(rep, method,
+ tidyselect::all_of(c(metric, x_str, facet_str, linetype_str))) %>%
+ tidyr::pivot_longer(
+ cols = tidyselect::all_of(metric), names_to = "metric"
+ ) %>%
+ dplyr::group_by(
+ method, metric, dplyr::across(tidyselect::all_of(c(x_str, facet_str, linetype_str)))
+ ) %>%
+ dplyr::summarise(mean = mean(value),
+ sd = sd(value) / sqrt(dplyr::n()),
+ .groups = "keep") %>%
+ dplyr::filter(method %in% show_methods) %>%
+ dplyr::mutate(
+ method = factor(method, levels = show_methods),
+ metric = forcats::fct_recode(
+ factor(metric, levels = metric_names),
+ AUROC = "rocauc", AUPRC = "prauc", Accuracy = "accuracy"
+ )
+ )
+
+ if (is.null(linetype_str)) {
+ plt <- ggplot2::ggplot(plt_df) +
+ ggplot2::geom_point(
+ ggplot2::aes(x = .data[[x_str]], y = mean,
+ color = method, alpha = method, group = method),
+ size = point_size
+ ) +
+ ggplot2::geom_line(
+ ggplot2::aes(x = .data[[x_str]], y = mean,
+ color = method, alpha = method, group = method),
+ size = line_size
+ ) +
+ ggplot2::geom_errorbar(
+ ggplot2::aes(x = .data[[x_str]], ymin = mean - sd, ymax = mean + sd,
+ color = method, alpha = method, group = method),
+ width = errbar_width, show.legend = FALSE
+ )
+ } else {
+ plt <- ggplot2::ggplot(plt_df) +
+ ggplot2::geom_point(
+ ggplot2::aes(x = .data[[x_str]], y = mean,
+ color = method, alpha = method, group = interaction(method, !!rlang::sym(linetype_str))),
+ size = point_size
+ ) +
+ ggplot2::geom_line(
+ ggplot2::aes(x = .data[[x_str]], y = mean,
+ color = method, alpha = method, group = interaction(method, !!rlang::sym(linetype_str)),
+ linetype = !!rlang::sym(linetype_str)),
+ size = line_size
+ ) +
+ ggplot2::geom_errorbar(
+ ggplot2::aes(x = .data[[x_str]], ymin = mean - sd, ymax = mean + sd,
+ color = method, alpha = method, group = interaction(method, !!rlang::sym(linetype_str))),
+ width = errbar_width, show.legend = FALSE
+ )
+ }
+ if (!is.null(manual_color_palette)) {
+ if (is.null(alpha_values)) {
+ alpha_values <- c(1, rep(alpha, length(method_labels) - 1))
+ }
+ plt <- plt +
+ ggplot2::scale_color_manual(
+ values = manual_color_palette, labels = method_labels
+ ) +
+ ggplot2::scale_alpha_manual(
+ values = alpha_values, labels = method_labels
+ )
+ }
+ if (!is.null(custom_theme)) {
+ plt <- plt + custom_theme
+ }
+
+ if (!is.null(facet_str)) {
+ plt <- plt +
+ ggplot2::facet_grid(reformulate(facet_str, "metric"), scales = "free")
+ } else if (length(metric) > 1) {
+ plt <- plt +
+ ggplot2::facet_wrap(~ metric, scales = "free")
+ }
+
+ if (inside_legend) {
+ if (is.null(legend_position)) {
+ legend_position <- c(0.75, 0.3)
+ }
+ plt <- plt +
+ ggplot2::theme(
+ legend.title = ggplot2::element_blank(),
+ legend.position = legend_position,
+ legend.background = ggplot2::element_rect(
+ color = "slategray", size = 0.1
+ )
+ )
+ }
+
+ return(plt)
+}
+
+# plot restricted metrics (mean value across repetitions with error bars)
+plot_restricted_metrics <- function(results, metric = c("rocauc", "prauc"),
+ x_str, facet_str,
+ quantiles = c(0.1, 0.2, 0.3, 0.4),
+ point_size = 1, line_size = 1, errbar_width = 0,
+ alpha = 0.5, inside_legend = FALSE,
+ manual_color_palette = NULL,
+ show_methods = NULL,
+ method_labels = ggplot2::waiver(),
+ custom_theme = vthemes::theme_vmodern(size_preset = "medium")) {
+ if (is.null(show_methods)) {
+ show_methods <- sort(unique(results$method))
+ }
+ results <- results %>%
+ dplyr::select(rep, method, fi_scores,
+ tidyselect::all_of(c(x_str, facet_str))) %>%
+ dplyr::mutate(
+ vars_ordered = purrr::map(
+ fi_scores,
+ function(fi_df) {
+ fi_df %>%
+ dplyr::filter(!is.na(cor_with_signal)) %>%
+ dplyr::arrange(-cor_with_signal) %>%
+ dplyr::pull(var)
+ }
+ )
+ )
+
+ plt_df_ls <- list()
+ for (q in quantiles) {
+ plt_df_ls[[as.character(q)]] <- results %>%
+ dplyr::mutate(
+ restricted_metrics = purrr::map2_dfr(
+ fi_scores, vars_ordered,
+ function(fi_df, ignore_vars) {
+ ignore_vars <- ignore_vars[1:round(q * length(ignore_vars))]
+ auroc_r <- fi_df %>%
+ dplyr::filter(!(var %in% ignore_vars)) %>%
+ yardstick::roc_auc(
+ truth = factor(true_support, levels = c("1", "0")), importance,
+ event_level = "first"
+ ) %>%
+ dplyr::pull(.estimate)
+ auprc_r <- fi_df %>%
+ dplyr::filter(!(var %in% ignore_vars)) %>%
+ yardstick::pr_auc(
+ truth = factor(true_support, levels = c("1", "0")), importance,
+ event_level = "first"
+ ) %>%
+ dplyr::pull(.estimate)
+ return(data.frame(restricted_auroc = auroc_r,
+ restricted_auprc = auprc_r))
+ }
+ )
+ ) %>%
+ tidyr::unnest(restricted_metrics) %>%
+ tidyr::pivot_longer(
+ cols = c(restricted_auroc, restricted_auprc), names_to = "metric"
+ ) %>%
+ dplyr::group_by(
+ method, metric, dplyr::across(tidyselect::all_of(c(x_str, facet_str)))
+ ) %>%
+ dplyr::summarise(mean = mean(value),
+ sd = sd(value) / sqrt(dplyr::n()),
+ .groups = "keep") %>%
+ dplyr::ungroup() %>%
+ dplyr::filter(method %in% show_methods) %>%
+ dplyr::mutate(
+ method = factor(method, levels = show_methods),
+ metric = forcats::fct_recode(
+ factor(metric, levels = c("restricted_auroc", "restricted_auprc")),
+ `Restricted AUROC` = "restricted_auroc",
+ `Restricted AUPRC` = "restricted_auprc"
+ )
+ )
+ }
+
+ plt_df <- purrr::map_dfr(plt_df_ls, ~.x, .id = ".threshold") %>%
+ dplyr::mutate(.threshold = as.numeric(.threshold))
+
+ plt <- ggplot2::ggplot(plt_df) +
+ ggplot2::geom_point(
+ ggplot2::aes(x = .data[[x_str]], y = mean,
+ color = method, alpha = method, group = method),
+ size = point_size
+ ) +
+ ggplot2::geom_line(
+ ggplot2::aes(x = .data[[x_str]], y = mean,
+ color = method, alpha = method, group = method),
+ size = line_size
+ ) +
+ ggplot2::geom_errorbar(
+ ggplot2::aes(x = .data[[x_str]], ymin = mean - sd, ymax = mean + sd,
+ color = method, alpha = method, group = method),
+ width = errbar_width, show.legend = FALSE
+ )
+ if (!is.null(manual_color_palette)) {
+ plt <- plt +
+ ggplot2::scale_color_manual(
+ values = manual_color_palette, labels = method_labels
+ ) +
+ ggplot2::scale_alpha_manual(
+ values = c(1, rep(alpha, length(method_labels) - 1)),
+ labels = method_labels
+ )
+ }
+ if (!is.null(custom_theme)) {
+ plt <- plt + custom_theme
+ }
+
+ if (!is.null(facet_str)) {
+ formula <- sprintf("metric + .threshold ~ %s", paste0(facet_str, collapse = " + "))
+ plt <- plt +
+ ggplot2::facet_grid(as.formula(formula))
+ } else {
+ plt <- plt +
+ ggplot2::facet_wrap(.threshold ~ metric, scales = "free")
+ }
+
+ if (inside_legend) {
+ plt <- plt +
+ ggplot2::theme(
+ legend.title = ggplot2::element_blank(),
+ legend.position = c(0.75, 0.3),
+ legend.background = ggplot2::element_rect(
+ color = "slategray", size = 0.1
+ )
+ )
+ }
+
+ return(plt)
+}
+
+# plot true positive rate across # positives
+plot_tpr <- function(results, facet_vars, point_size = 0.85,
+ manual_color_palette = NULL,
+ show_methods = NULL,
+ method_labels = ggplot2::waiver(),
+ custom_theme = vthemes::theme_vmodern(size_preset = "medium")) {
+ if (is.null(results)) {
+ return(NULL)
+ }
+
+ if (is.null(show_methods)) {
+ show_methods <- sort(unique(results$method))
+ }
+
+ facet_names <- names(facet_vars)
+ names(facet_vars) <- NULL
+
+ plt_df <- results %>%
+ dplyr::mutate(
+ fi_scores = mapply(name = fi, scores_df = fi_scores,
+ function(name, scores_df) {
+ scores_df <- scores_df %>%
+ dplyr::mutate(
+ ranking = rank(-importance,
+ ties.method = "random")
+ ) %>%
+ dplyr::arrange(ranking) %>%
+ dplyr::mutate(
+ .tp = cumsum(true_support) / sum(true_support)
+ )
+ return(scores_df)
+ }, SIMPLIFY = FALSE)
+ ) %>%
+ tidyr::unnest(fi_scores) %>%
+ dplyr::select(tidyselect::all_of(facet_vars), rep, method, ranking, .tp) %>%
+ dplyr::group_by(
+ dplyr::across(tidyselect::all_of(facet_vars)), method, ranking
+ ) %>%
+ dplyr::summarise(.tp = mean(.tp), .groups = "keep") %>%
+ dplyr::mutate(method = factor(method, levels = show_methods)) %>%
+ dplyr::ungroup()
+
+ if (!is.null(facet_names)) {
+ for (i in 1:length(facet_vars)) {
+ facet_var <- facet_vars[i]
+ facet_name <- facet_names[i]
+ if (facet_name != "") {
+ plt_df <- plt_df %>%
+ dplyr::mutate(dplyr::across(
+ tidyselect::all_of(facet_var),
+ ~factor(sprintf("%s = %s", facet_name, .x),
+ levels = sprintf("%s = %s", facet_name, sort(unique(.x))))
+ ))
+ }
+ }
+ }
+
+ if (length(facet_vars) == 1) {
+ plt <- ggplot2::ggplot(plt_df) +
+ ggplot2::aes(x = ranking, y = .tp, color = method) +
+ ggplot2::geom_line(size = point_size) +
+ ggplot2::facet_wrap(reformulate(facet_vars))
+ } else {
+ plt <- ggplot2::ggplot(plt_df) +
+ ggplot2::aes(x = ranking, y = .tp, color = method) +
+ ggplot2::geom_line(size = point_size) +
+ ggplot2::facet_grid(reformulate(facet_vars[1], facet_vars[2]))
+ }
+ plt <- plt +
+ ggplot2::labs(x = "Top n", y = "True Positive Rate",
+ fill = "Method", color = "Method")
+ if (!is.null(manual_color_palette)) {
+ plt <- plt +
+ ggplot2::scale_color_manual(
+ values = manual_color_palette, labels = method_labels
+ )
+ }
+ if (!is.null(custom_theme)) {
+ plt <- plt + custom_theme
+ }
+
+ return(plt)
+}
+
+# plot stability results
+plot_perturbation_stability <- function(results,
+ facet_rows = "heritability_name",
+ facet_cols = "rho_name",
+ param_name = NULL,
+ facet_row_names = "Avg Rank (PVE = %s)",
+ group_fun = NULL,
+ descending_methods = NULL,
+ manual_color_palette = NULL,
+ show_methods = NULL,
+ method_labels = ggplot2::waiver(),
+ plot_types = c("boxplot", "errbar"),
+ save_dir = ".",
+ save_filename = NULL,
+ fig_height = 11,
+ fig_width = 11,
+ ...) {
+ plot_types <- match.arg(plot_types, several.ok = TRUE)
+
+ my_theme <- vthemes::theme_vmodern(
+ size_preset = "medium", bg_color = "white", grid_color = "white",
+ axis.title = ggplot2::element_text(size = 12, face = "plain"),
+ legend.title = ggplot2::element_blank(),
+ legend.text = ggplot2::element_text(size = 9),
+ legend.text.align = 0,
+ plot.title = ggplot2::element_blank()
+ )
+
+ if (is.null(group_fun)) {
+ group_fun <- function(var, sig_ids, cnsig_ids) {
+ dplyr::case_when(
+ var %in% sig_ids ~ "Sig",
+ var %in% cnsig_ids ~ "C-NSig",
+ TRUE ~ "NSig"
+ ) %>%
+ factor(levels = c("Sig", "C-NSig", "NSig"))
+ }
+ }
+
+ if (!is.null(show_methods)) {
+ results <- results %>%
+ dplyr::filter(fi %in% show_methods)
+ if (!identical(method_labels, ggplot2::waiver())) {
+ method_names <- show_methods
+ names(method_names) <- method_labels
+ results$fi <- do.call(forcats::fct_recode,
+ args = c(list(results$fi), as.list(method_names)))
+ results$fi <- factor(results$fi, levels = method_labels)
+ method_labels <- ggplot2::waiver()
+ }
+ }
+
+ if (!is.null(descending_methods)) {
+ results <- results %>%
+ dplyr::mutate(
+ importance = ifelse(fi %in% descending_methods, -importance, importance)
+ )
+ }
+
+ rankings <- results %>%
+ dplyr::group_by(
+ rep, fi,
+ dplyr::across(tidyselect::all_of(c(facet_rows, facet_cols, param_name))),
+ ) %>%
+ dplyr::mutate(
+ rank = rank(-importance),
+ group = group_fun(var, ...)
+ ) %>%
+ dplyr::ungroup()
+
+ agg_rankings <- rankings %>%
+ dplyr::group_by(
+ rep, fi, group,
+ dplyr::across(tidyselect::all_of(c(facet_rows, facet_cols, param_name)))
+ ) %>%
+ dplyr::summarise(
+ avgrank = mean(rank),
+ .groups = "keep"
+ ) %>%
+ dplyr::ungroup()
+
+ ymin <- min(agg_rankings$avgrank)
+ ymax <- max(agg_rankings$avgrank)
+
+ for (type in plot_types) {
+ plt_ls <- list()
+ for (val in unique(agg_rankings[[facet_rows]])) {
+ if (identical(type, "boxplot")) {
+ plt <- agg_rankings %>%
+ dplyr::filter(.data[[facet_rows]] == val) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = group, y = avgrank, color = fi) +
+ ggplot2::geom_boxplot()
+ } else if (identical(type, "errbar")) {
+ plt <- agg_rankings %>%
+ dplyr::filter(.data[[facet_rows]] == val) %>%
+ dplyr::group_by(
+ fi, group, dplyr::across(tidyselect::all_of(facet_cols))
+ ) %>%
+ dplyr::summarise(
+ .mean = mean(avgrank),
+ .sd = sd(avgrank),
+ .groups = "keep"
+ ) %>%
+ dplyr::ungroup() %>%
+ ggplot2::ggplot() +
+ # ggplot2::geom_point(
+ # ggplot2::aes(x = group, y = .mean, color = fi, group = fi),
+ # position = ggplot2::position_dodge2(width = 0.8, padding = 0.8)
+ # ) +
+ ggplot2::geom_errorbar(
+ ggplot2::aes(x = group, ymin = .mean - .sd, ymax = .mean + .sd,
+ color = fi, group = fi),
+ position = ggplot2::position_dodge2(width = 0, padding = 0.5),
+ width = 0.5, show.legend = FALSE
+ )
+ }
+ plt <- plt +
+ ggplot2::coord_cartesian(ylim = c(ymin, ymax)) +
+ ggplot2::facet_grid(~ .data[[facet_cols]], labeller = ggplot2::label_parsed) +
+ my_theme +
+ ggplot2::theme(
+ panel.grid.major = ggplot2::element_line(colour = "#d9d9d9"),
+ panel.grid.major.x = ggplot2::element_blank(),
+ panel.grid.minor.x = ggplot2::element_blank(),
+ axis.line.y = ggplot2::element_blank(),
+ axis.ticks.y = ggplot2::element_blank(),
+ legend.position = "right"
+ ) +
+ ggplot2::labs(x = "Feature Groups", y = sprintf(facet_row_names, val))
+ if (!is.null(manual_color_palette)) {
+ plt <- plt +
+ ggplot2::scale_color_manual(values = manual_color_palette,
+ labels = method_labels)
+ }
+ if (length(plt_ls) != 0) {
+ plt <- plt +
+ ggplot2::theme(strip.text = ggplot2::element_blank())
+ }
+ plt_ls[[as.character(val)]] <- plt
+ }
+ agg_plt <- patchwork::wrap_plots(plt_ls) +
+ patchwork::plot_layout(ncol = 1, guides = "collect")
+ if (!is.null(save_filename)) {
+ ggplot2::ggsave(
+ filename = file.path(save_dir,
+ sprintf("%s_%s_aggregated.pdf", save_filename, type)),
+ plot = agg_plt, units = "in", width = fig_width, height = fig_height
+ )
+ }
+ }
+
+ unagg_plt <- NULL
+ if (!is.null(param_name)) {
+ plt_ls <- list()
+ for (val in unique(agg_rankings[[facet_rows]])) {
+ plt <- agg_rankings %>%
+ dplyr::filter(.data[[facet_rows]] == val) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(x = .data[[param_name]], y = avgrank, color = fi) +
+ ggplot2::geom_boxplot() +
+ ggplot2::coord_cartesian(ylim = c(ymin, ymax)) +
+ ggplot2::facet_grid(
+ reformulate(c(facet_cols, "group"), "fi"),
+ labeller = ggplot2::label_parsed
+ ) +
+ my_theme +
+ ggplot2::theme(
+ panel.grid.major = ggplot2::element_line(colour = "#d9d9d9"),
+ panel.grid.major.x = ggplot2::element_blank(),
+ panel.grid.minor.x = ggplot2::element_blank(),
+ axis.line.y = ggplot2::element_blank(),
+ axis.ticks.y = ggplot2::element_blank(),
+ legend.position = "none"
+ ) +
+ ggplot2::labs(x = param_name, y = sprintf(facet_row_names, val))
+ if (!is.null(manual_color_palette)) {
+ plt <- plt +
+ ggplot2::scale_color_manual(values = manual_color_palette,
+ labels = method_labels)
+ }
+ plt_ls[[as.character(val)]] <- plt
+ }
+ unagg_plt <- patchwork::wrap_plots(plt_ls) +
+ patchwork::plot_layout(guides = "collect")
+
+ if (!is.null(save_filename)) {
+ ggplot2::ggsave(
+ filename = file.path(save_dir,
+ sprintf("%s_unaggregated.pdf", save_filename)),
+ plot = unagg_plt, units = "in", width = fig_width, height = fig_height
+ )
+ }
+ }
+ return(list(agg = agg_plt, unagg = unagg_plt))
+}
+
+
+plot_top_stability <- function(results,
+ group_id = NULL,
+ top_r = 10,
+ show_max_features = 5,
+ varnames = NULL,
+ base_method = "MDI+_ridge",
+ return_df = FALSE,
+ descending_methods = NULL,
+ manual_color_palette = NULL,
+ show_methods = NULL,
+ method_labels = ggplot2::waiver(),
+ ...) {
+
+ plt_ls <- list()
+ if (!is.null(show_methods)) {
+ results <- results %>%
+ dplyr::filter(fi %in% show_methods)
+ if (!identical(method_labels, ggplot2::waiver())) {
+ method_names <- show_methods
+ names(method_names) <- method_labels
+ results$fi <- do.call(forcats::fct_recode,
+ args = c(list(results$fi), as.list(method_names)))
+ results$fi <- factor(results$fi, levels = method_labels)
+ method_labels <- ggplot2::waiver()
+ }
+ }
+
+ if (!is.null(descending_methods)) {
+ results <- results %>%
+ dplyr::mutate(
+ importance = ifelse(fi %in% descending_methods, -importance, importance)
+ )
+ }
+
+ if (is.null(varnames)) {
+ varnames <- 0:(length(unique(results$var)) - 1)
+ }
+ varnames_df <- data.frame(var = 0:(length(varnames) - 1), Feature = varnames)
+
+ rankings <- results %>%
+ dplyr::group_by(
+ rep, fi, dplyr::across(tidyselect::all_of(group_id))
+ ) %>%
+ dplyr::mutate(
+ rank = rank(-importance),
+ in_top_r = importance >= sort(importance, decreasing = TRUE)[top_r]
+ ) %>%
+ dplyr::ungroup()
+
+ stability_df <- rankings %>%
+ dplyr::group_by(
+ fi, var, dplyr::across(tidyselect::all_of(group_id))
+ ) %>%
+ dplyr::summarise(
+ stability_score = mean(in_top_r),
+ .groups = "keep"
+ ) %>%
+ dplyr::ungroup()
+
+ n_nonzero_stability_df <- stability_df %>%
+ dplyr::group_by(fi, dplyr::across(tidyselect::all_of(group_id))) %>%
+ dplyr::summarise(
+ n_features = sum(stability_score > 0),
+ .groups = "keep"
+ ) %>%
+ dplyr::ungroup()
+
+ if (!is.null(group_id)) {
+ order_groups <- n_nonzero_stability_df %>%
+ dplyr::group_by(dplyr::across(tidyselect::all_of(group_id))) %>%
+ dplyr::summarise(
+ order = min(n_features)
+ ) %>%
+ dplyr::arrange(order) %>%
+ dplyr::pull(tidyselect::all_of(group_id)) %>%
+ unique()
+
+ n_nonzero_stability_df <- n_nonzero_stability_df %>%
+ dplyr::mutate(
+ dplyr::across(
+ tidyselect::all_of(group_id), ~factor(.x, levels = order_groups)
+ )
+ )
+
+ ytext_label_colors <- n_nonzero_stability_df %>%
+ dplyr::group_by(dplyr::across(tidyselect::all_of(group_id))) %>%
+ dplyr::summarise(
+ is_best_method = n_features[fi == base_method] == min(n_features)
+ ) %>%
+ dplyr::mutate(
+ color = ifelse(is_best_method, "black", "#4a86e8")
+ ) %>%
+ dplyr::arrange(tidyselect::all_of(group_id)) %>%
+ dplyr::pull(color)
+
+ plt <- vdocs::plot_horizontal_dotplot(
+ n_nonzero_stability_df,
+ x_str = "n_features", y_str = group_id, color_str = "fi",
+ theme_options = list(size_preset = "xlarge")
+ ) +
+ ggplot2::labs(
+ x = sprintf("Number of Distinct Features in Top 10 Across %s RF Fits",
+ length(unique(results$rep))),
+ color = "Method"
+ ) +
+ ggplot2::scale_y_discrete(limits = rev) +
+ ggplot2::theme(
+ axis.text.y = ggplot2::element_text(color = rev(ytext_label_colors))
+ )
+
+ if (!is.null(manual_color_palette)) {
+ plt <- plt +
+ ggplot2::scale_color_manual(values = manual_color_palette,
+ labels = method_labels,
+ guide = ggplot2::guide_legend(reverse = TRUE))
+ }
+
+ plt_ls[["Summary"]] <- plt
+
+ } else {
+ plt <- stability_df %>%
+ dplyr::left_join(y = varnames_df, by = "var") %>%
+ dplyr::filter(stability_score > 0) %>%
+ dplyr::left_join(y = n_nonzero_stability_df, by = "fi") %>%
+ dplyr::mutate(fi = sprintf("%s (# features = %s)", fi, n_features)) %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(
+ x = reorder(Feature, -stability_score), y = stability_score
+ ) +
+ ggplot2::facet_wrap(~ fi, scales = "free_x", ncol = 1) +
+ ggplot2::geom_bar(stat = "identity") +
+ ggplot2::labs(
+ x = "Gene", y = sprintf("Proportion of RF Fits in Top %s", top_r)
+ ) +
+ vthemes::theme_vmodern(x_text_angle = TRUE, size_preset = "medium")
+
+ if (!is.null(manual_color_palette)) {
+ plt <- plt +
+ ggplot2::scale_color_manual(values = manual_color_palette,
+ labels = method_labels,
+ guide = ggplot2::guide_legend(reverse = TRUE))
+ }
+ plt_ls[["Non-zero Stability Scores Per Method"]] <- plt
+ }
+
+ if (is.null(group_id)) {
+ rankings <- rankings %>%
+ dplyr::mutate(
+ .group = "Stability of Top Features"
+ )
+ group_id <- ".group"
+ }
+
+ for (group in unique(rankings[[group_id]])) {
+ plt <- rankings %>%
+ dplyr::filter(.data[[group_id]] == group) %>%
+ dplyr::group_by(fi, var) %>%
+ dplyr::summarise(
+ mean_rank = mean(rank),
+ median_rank = median(rank),
+ `SD(Feature Rank)` = sd(rank),
+ .groups = "keep"
+ ) %>%
+ dplyr::group_by(fi) %>%
+ dplyr::mutate(
+ agg_feature_rank = rank(mean_rank, ties.method = "random")
+ ) %>%
+ dplyr::filter(agg_feature_rank <= show_max_features) %>%
+ dplyr::mutate(`Average Feature Rank` = as.factor(agg_feature_rank)) %>%
+ dplyr::left_join(y = varnames_df, by = "var") %>%
+ dplyr::rename("Method" = "fi") %>%
+ dplyr::ungroup() %>%
+ ggplot2::ggplot() +
+ ggplot2::aes(
+ x = `Average Feature Rank`, y = `SD(Feature Rank)`, fill = Method, label = Feature
+ ) +
+ ggplot2::geom_bar(
+ stat = "identity", position = "dodge"
+ ) +
+ ggplot2::labs(title = group) +
+ vthemes::scale_fill_vmodern(discrete = TRUE) +
+ vthemes::theme_vmodern()
+
+ if (!is.null(manual_color_palette)) {
+ plt <- plt +
+ ggplot2::scale_fill_manual(values = manual_color_palette,
+ labels = method_labels)
+ }
+ plt_ls[[group]] <- plt
+ }
+
+ if (return_df) {
+ return(list(plot_ls = plt_ls, rankings = rankings, stability = stability_df))
+ } else {
+ return(plt_ls)
+ }
+}
+
+
diff --git a/feature_importance/util.py b/feature_importance/util.py
new file mode 100644
index 0000000..cb95e03
--- /dev/null
+++ b/feature_importance/util.py
@@ -0,0 +1,200 @@
+import copy
+import os
+import warnings
+from functools import partial
+from os.path import dirname
+from os.path import join as oj
+from typing import Any, Dict, Tuple
+
+import numpy as np
+import pandas as pd
+from sklearn import model_selection
+from sklearn.metrics import confusion_matrix, precision_recall_curve, auc, average_precision_score, mean_absolute_error
+from sklearn.preprocessing import label_binarize
+from sklearn.utils._encode import _unique
+from sklearn import metrics
+from imodels.importance.ppms import huber_loss
+
+DATASET_PATH = oj(dirname(os.path.realpath(__file__)), 'data')
+
+
+class ModelConfig:
+ def __init__(self,
+ name: str, cls,
+ vary_param: str = None, vary_param_val: Any = None,
+ other_params: Dict[str, Any] = {},
+ model_type: str = None):
+ """
+ name: str
+ Name of the model.
+ vary_param: str
+ Name of the parameter to be varied
+ model_type: str
+ Type of model. ID is used to pair with FIModel.
+ """
+
+ self.name = name
+ self.cls = cls
+ self.model_type = model_type
+ self.vary_param = vary_param
+ self.vary_param_val = vary_param_val
+ self.kwargs = {}
+ if self.vary_param is not None:
+ self.kwargs[self.vary_param] = self.vary_param_val
+ self.kwargs = {**self.kwargs, **other_params}
+
+ def __repr__(self):
+ return self.name
+
+
+class FIModelConfig:
+ def __init__(self,
+ name: str, cls, ascending = True,
+ splitting_strategy: str = None,
+ vary_param: str = None, vary_param_val: Any = None,
+ other_params: Dict[str, Any] = {},
+ model_type: str = None):
+ """
+ ascending: boolean
+ Whether or not feature importances should be ranked in ascending or
+ descending order. Default is True, indicating that higher feature
+ importance score is more important.
+ splitting_strategy: str
+ See util.apply_splitting_strategy(). Common inputs are "train-test" and None
+ vary_param: str
+ Name of the parameter to be varied
+ model_type: str
+ Type of model. ID is used to pair FIModel with Model.
+ """
+
+ assert splitting_strategy in {
+ 'train-test', 'train-tune-test', 'train-test-lowdata', 'train-tune-test-lowdata',
+ 'train-test-prediction', None
+ }
+
+ self.name = name
+ self.cls = cls
+ self.ascending = ascending
+ self.model_type = model_type
+ self.splitting_strategy = splitting_strategy
+ self.vary_param = vary_param
+ self.vary_param_val = vary_param_val
+ self.kwargs = {}
+ if self.vary_param is not None:
+ self.kwargs[self.vary_param] = self.vary_param_val
+ self.kwargs = {**self.kwargs, **other_params}
+
+ def __repr__(self):
+ return self.name
+
+
+def tp(y_true, y_pred):
+ conf_mat = confusion_matrix(y_true, y_pred)
+ return conf_mat[1][1]
+
+
+def fp(y_true, y_pred):
+ conf_mat = confusion_matrix(y_true, y_pred)
+ return conf_mat[0][1]
+
+
+def neg(y_true, y_pred):
+ return sum(y_pred == 0)
+
+
+def pos(y_true, y_pred):
+ return sum(y_pred == 1)
+
+
+def specificity_score(y_true, y_pred):
+ conf_mat = confusion_matrix(y_true, y_pred)
+ return conf_mat[0][0] / (conf_mat[0][0] + conf_mat[0][1])
+
+
+def auprc_score(y_true, y_score, multi_class="raise"):
+ assert multi_class in ["raise", "ovr"]
+ n_classes = len(np.unique(y_true))
+ if n_classes <= 2:
+ precision, recall, _ = precision_recall_curve(y_true, y_score)
+ return auc(recall, precision)
+ else:
+ # ovr is same as multi-label
+ if multi_class == "raise":
+ raise ValueError("Must set multi_class='ovr' to evaluate multi-class predictions.")
+ classes = _unique(y_true)
+ y_true_multilabel = label_binarize(y_true, classes=classes)
+ return average_precision_score(y_true_multilabel, y_score)
+
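+# Illustrative usage sketch (not part of the original pipeline; the labels and scores are made up):
+#   y_true = np.array([0, 0, 1, 1])
+#   y_score = np.array([0.1, 0.4, 0.35, 0.8])
+#   auprc_score(y_true, y_score)   # binary case: area under the precision-recall curve
+#   # for more than two classes, pass multi_class="ovr" and one column of scores per class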
+
+def auroc_score(y_true, y_score, multi_class="raise", **kwargs):
+ assert multi_class in ["raise", "ovr"]
+ n_classes = len(np.unique(y_true))
+ if n_classes <= 2:
+ return metrics.roc_auc_score(y_true, y_score, multi_class=multi_class, **kwargs)
+ else:
+ # ovr is same as multi-label
+ if multi_class == "raise":
+ raise ValueError("Must set multi_class='ovr' to evaluate multi-class predictions.")
+ classes = _unique(y_true)
+ y_true_multilabel = label_binarize(y_true, classes=classes)
+ return metrics.roc_auc_score(y_true_multilabel, y_score, **kwargs)
+
+
+def neg_mean_absolute_error(y_true, y_pred, **kwargs):
+ return -mean_absolute_error(y_true, y_pred, **kwargs)
+
+
+def neg_huber_loss(y_true, y_pred, **kwargs):
+ return -huber_loss(y_true, y_pred, **kwargs)
+
+
+def restricted_roc_auc_score(y_true, y_score, ignored_indices=[]):
+ """
+ Compute AUROC score for only a subset of the samples
+
+ :param y_true: true binary labels
+ :param y_score: predicted scores or probabilities
+ :param ignored_indices: indices of samples to exclude from the computation
+ :return: AUROC computed on the remaining samples
+ """
+ n_examples = len(y_true)
+ mask = [i for i in range(n_examples) if i not in ignored_indices]
+ restricted_auc = metrics.roc_auc_score(np.array(y_true)[mask], np.array(y_score)[mask])
+ return restricted_auc
+
+
+def compute_nsg_feat_corr_w_sig_subspace(signal_features, nonsignal_features, normalize=True):
+
+ if normalize:
+ normalized_nsg_features = nonsignal_features / np.linalg.norm(nonsignal_features, axis=0)
+ else:
+ normalized_nsg_features = nonsignal_features
+
+ q, r = np.linalg.qr(signal_features)
+ projections = np.linalg.norm(q.T @ normalized_nsg_features, axis=0)
+ # nsg_feat_ranked_by_projection = np.argsort(projections)
+
+ return projections
+
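+# Illustrative note (not from the original source): each returned entry is the Euclidean norm of
+# the projection of a (unit-normalized) non-signal column onto the column space of the signal
+# features, so values lie in [0, 1]. Sketch with arbitrary shapes:
+#   sig = np.random.normal(size=(100, 5)); nsg = np.random.normal(size=(100, 20))
+#   compute_nsg_feat_corr_w_sig_subspace(sig, nsg)   # array of shape (20,)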
+
+def apply_splitting_strategy(X: np.ndarray,
+ y: np.ndarray,
+ splitting_strategy: str,
+ split_seed: int) -> Tuple[Any, Any, Any, Any, Any, Any]:
+ if splitting_strategy in {'train-test-lowdata', 'train-tune-test-lowdata'}:
+ test_size = 0.90 # X.shape[0] - X.shape[0] * 0.1
+ elif splitting_strategy == "train-test":
+ test_size = 0.33
+ else:
+ test_size = 0.2
+
+ X_train, X_test, y_train, y_test = model_selection.train_test_split(
+ X, y, test_size=test_size, random_state=split_seed)
+ X_tune = None
+ y_tune = None
+
+ if splitting_strategy in {'train-tune-test', 'train-tune-test-lowdata'}:
+ X_train, X_tune, y_train, y_tune = model_selection.train_test_split(
+ X_train, y_train, test_size=0.2, random_state=split_seed)
+
+ return X_train, X_tune, X_test, y_train, y_tune, y_test
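+
+# Illustrative usage sketch (not part of the original pipeline; shapes assume the "train-test"
+# strategy, which uses test_size = 0.33):
+#   X, y = np.random.normal(size=(300, 10)), np.random.normal(size=300)
+#   X_train, X_tune, X_test, y_train, y_tune, y_test = apply_splitting_strategy(
+#       X, y, splitting_strategy="train-test", split_seed=0)
+#   # X_tune and y_tune are None for "train-test"; roughly 200 train rows and 100 test rows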