diff --git a/feature_importance/01_run_importance_simulations.py b/feature_importance/01_run_importance_simulations.py new file mode 100644 index 0000000..874f431 --- /dev/null +++ b/feature_importance/01_run_importance_simulations.py @@ -0,0 +1,475 @@ +# Example usage: run in command line +# cd feature_importance/ +# python 01_run_simulations.py --nreps 2 --config test --split_seed 12345 --ignore_cache +# python 01_run_simulations.py --nreps 2 --config test --split_seed 12345 --ignore_cache --create_rmd + +import copy +import os +from os.path import join as oj +import glob +import argparse +import pickle as pkl +import time +import warnings +from scipy import stats +import dask +from dask.distributed import Client +import numpy as np +import pandas as pd +from tqdm import tqdm +import sys +from collections import defaultdict +from typing import Callable, List, Tuple +import itertools +from sklearn.metrics import roc_auc_score, f1_score, recall_score, precision_score + +sys.path.append(".") +sys.path.append("..") +sys.path.append("../..") +import fi_config +from util import ModelConfig, FIModelConfig, tp, fp, neg, pos, specificity_score, auroc_score, auprc_score, compute_nsg_feat_corr_w_sig_subspace, apply_splitting_strategy + +warnings.filterwarnings("ignore", message="Bins whose width") + + +def compare_estimators(estimators: List[ModelConfig], + fi_estimators: List[FIModelConfig], + X, y, support: List, + metrics: List[Tuple[str, Callable]], + args, ) -> Tuple[dict, dict]: + """Calculates results given estimators, feature importance estimators, datasets, and metrics. + Called in run_comparison + """ + if type(estimators) != list: + raise Exception("First argument needs to be a list of Models") + if type(metrics) != list: + raise Exception("Argument metrics needs to be a list containing ('name', callable) pairs") + + # initialize results + results = defaultdict(lambda: []) + + # loop over model estimators + for model in tqdm(estimators, leave=False): + est = model.cls(**model.kwargs) + + # get kwargs for all fi_ests + fi_kwargs = {} + for fi_est in fi_estimators: + fi_kwargs.update(fi_est.kwargs) + + # get groups of estimators for each splitting strategy + fi_ests_dict = defaultdict(list) + for fi_est in fi_estimators: + fi_ests_dict[fi_est.splitting_strategy].append(fi_est) + + # loop over splitting strategies + for splitting_strategy, fi_ests in fi_ests_dict.items(): + # implement provided splitting strategy + if splitting_strategy is not None: + X_train, X_tune, X_test, y_train, y_tune, y_test = apply_splitting_strategy(X, y, splitting_strategy, args.split_seed) + else: + X_train = X + X_tune = X + X_test = X + y_train = y + y_tune = y + y_test = y + + # fit model + est.fit(X_train, y_train) + + # compute correlation between signal and nonsignal features + x_cor = np.empty(len(support)) + x_cor[:] = np.NaN + x_cor[support == 0] = compute_nsg_feat_corr_w_sig_subspace(X_train[:, support == 1], X_train[:, support == 0]) + + # loop over fi estimators + for fi_est in fi_ests: + metric_results = { + 'model': model.name, + 'fi': fi_est.name, + 'splitting_strategy': splitting_strategy + } + start = time.time() + fi_score = fi_est.cls(X_test, y_test, copy.deepcopy(est), **fi_est.kwargs) + end = time.time() + support_df = pd.DataFrame({"var": np.arange(len(support)), + "true_support": support, + "cor_with_signal": x_cor}) + metric_results['fi_scores'] = pd.merge(copy.deepcopy(fi_score), support_df, on="var", how="left") + if np.max(support) != np.min(support): + for i, (met_name, met) in 
enumerate(metrics): + if met is not None: + imp_vals = copy.deepcopy(fi_score["importance"]) + imp_vals[imp_vals == float("-inf")] = -sys.maxsize - 1 + imp_vals[imp_vals == float("inf")] = sys.maxsize - 1 + if fi_est.ascending: + imp_vals[np.isnan(imp_vals)] = -sys.maxsize - 1 + metric_results[met_name] = met(support, imp_vals) + else: + imp_vals[np.isnan(imp_vals)] = sys.maxsize - 1 + metric_results[met_name] = met(support, -imp_vals) + metric_results['time'] = end - start + + # initialize results with metadata and metric results + kwargs: dict = model.kwargs # dict + for k in kwargs: + results[k].append(kwargs[k]) + for k in fi_kwargs: + if k in fi_est.kwargs: + results[k].append(str(fi_est.kwargs[k])) + else: + results[k].append(None) + for met_name, met_val in metric_results.items(): + results[met_name].append(met_val) + return results + + +def run_comparison(path: str, + X, y, support: List, + metrics: List[Tuple[str, Callable]], + estimators: List[ModelConfig], + fi_estimators: List[FIModelConfig], + args): + estimator_name = estimators[0].name.split(' - ')[0] + fi_estimators_all = [fi_estimator for fi_estimator in itertools.chain(*fi_estimators) \ + if fi_estimator.model_type in estimators[0].model_type] + model_comparison_files_all = [oj(path, f'{estimator_name}_{fi_estimator.name}_comparisons.pkl') \ + for fi_estimator in fi_estimators_all] + if args.parallel_id is not None: + model_comparison_files_all = [f'_{args.parallel_id[0]}.'.join(model_comparison_file.split('.')) \ + for model_comparison_file in model_comparison_files_all] + + fi_estimators = [] + model_comparison_files = [] + for model_comparison_file, fi_estimator in zip(model_comparison_files_all, fi_estimators_all): + if os.path.isfile(model_comparison_file) and not args.ignore_cache: + print( + f'{estimator_name} with {fi_estimator.name} results already computed and cached. use --ignore_cache to recompute') + else: + fi_estimators.append(fi_estimator) + model_comparison_files.append(model_comparison_file) + + if len(fi_estimators) == 0: + return + + results = compare_estimators(estimators=estimators, + fi_estimators=fi_estimators, + X=X, y=y, support=support, + metrics=metrics, + args=args) + + estimators_list = [e.name for e in estimators] + metrics_list = [m[0] for m in metrics] + + df = pd.DataFrame.from_dict(results) + df['split_seed'] = args.split_seed + if args.nosave_cols is not None: + nosave_cols = np.unique([x.strip() for x in args.nosave_cols.split(",")]) + else: + nosave_cols = [] + for col in nosave_cols: + if col in df.columns: + df = df.drop(columns=[col]) + + for model_comparison_file, fi_estimator in zip(model_comparison_files, fi_estimators): + output_dict = { + # metadata + 'sim_name': args.config, + 'estimators': estimators_list, + 'fi_estimators': fi_estimator.name, + 'metrics': metrics_list, + + # actual values + 'df': df.loc[df.fi == fi_estimator.name], + } + pkl.dump(output_dict, open(model_comparison_file, 'wb')) + return df + + +def get_metrics(): + return [('rocauc', auroc_score), ('prauc', auprc_score)] + + +def reformat_results(results): + results = results.reset_index().drop(columns=['index']) + fi_scores = pd.concat(results.pop('fi_scores').to_dict()). 
\ + reset_index(level=0).rename(columns={'level_0': 'index'}) + results_df = pd.merge(results, fi_scores, left_index=True, right_on="index") + return results_df + + +def run_simulation(i, path, val_name, X_params_dict, X_dgp, y_params_dict, y_dgp, ests, fi_ests, metrics, args): + os.makedirs(oj(path, val_name, "rep" + str(i)), exist_ok=True) + np.random.seed(i) + max_iter = 100 + iter = 0 + while iter <= max_iter: # regenerate data if y is constant + X = X_dgp(**X_params_dict) + y, support, beta = y_dgp(X, **y_params_dict, return_support=True) + if not all(y == y[0]): + break + iter += 1 + if iter > max_iter: + raise ValueError("Response y is constant.") + if args.omit_vars is not None: + omit_vars = np.unique([int(x.strip()) for x in args.omit_vars.split(",")]) + support = np.delete(support, omit_vars) + X = np.delete(X, omit_vars, axis=1) + del beta # note: beta is not currently supported when using omit_vars + + for est in ests: + results = run_comparison(path=oj(path, val_name, "rep" + str(i)), + X=X, y=y, support=support, + metrics=metrics, + estimators=est, + fi_estimators=fi_ests, + args=args) + return True + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + + default_dir = os.getenv("SCRATCH") + if default_dir is not None: + default_dir = oj(default_dir, "feature_importance", "results") + else: + default_dir = oj(os.path.dirname(os.path.realpath(__file__)), 'results') + + parser.add_argument('--nreps', type=int, default=2) + parser.add_argument('--model', type=str, default=None) # , default='c4') + parser.add_argument('--fi_model', type=str, default=None) # , default='c4') + parser.add_argument('--config', type=str, default='test') + parser.add_argument('--omit_vars', type=str, default=None) # comma-separated string of variables to omit + parser.add_argument('--nosave_cols', type=str, default="prediction_model") + + # for multiple reruns, should support varying split_seed + parser.add_argument('--ignore_cache', action='store_true', default=False) + parser.add_argument('--verbose', action='store_true', default=True) + parser.add_argument('--parallel', action='store_true', default=False) + parser.add_argument('--parallel_id', nargs='+', type=int, default=None) + parser.add_argument('--n_cores', type=int, default=None) + parser.add_argument('--split_seed', type=int, default=0) + parser.add_argument('--results_path', type=str, default=default_dir) + + # arguments for rmd output of results + parser.add_argument('--create_rmd', action='store_true', default=False) + parser.add_argument('--show_vars', type=int, default=None) + + args = parser.parse_args() + + if args.parallel: + if args.n_cores is None: + print(os.getenv("SLURM_CPUS_ON_NODE")) + n_cores = int(os.getenv("SLURM_CPUS_ON_NODE")) + else: + n_cores = args.n_cores + client = Client(n_workers=n_cores) + + ests, fi_ests, \ + X_dgp, X_params_dict, y_dgp, y_params_dict, \ + vary_param_name, vary_param_vals = fi_config.get_fi_configs(args.config) + + metrics = get_metrics() + + if args.model: + ests = list(filter(lambda x: args.model.lower() == x[0].name.lower(), ests)) + if args.fi_model: + fi_ests = list(filter(lambda x: args.fi_model.lower() == x[0].name.lower(), fi_ests)) + + if len(ests) == 0: + raise ValueError('No valid estimators', 'sim', args.config, 'models', args.model, 'fi', args.fi_model) + if len(fi_ests) == 0: + raise ValueError('No valid FI estimators', 'sim', args.config, 'models', args.model, 'fi', args.fi_model) + if args.verbose: + print('running', args.config, + 'ests', ests, + 'fi_ests', 
fi_ests) + print('\tsaving to', args.results_path) + + if args.omit_vars is not None: + results_dir = oj(args.results_path, args.config + "_omitted_vars") + else: + results_dir = oj(args.results_path, args.config) + + if isinstance(vary_param_name, list): + path = oj(results_dir, "varying_" + "_".join(vary_param_name), "seed" + str(args.split_seed)) + else: + path = oj(results_dir, "varying_" + vary_param_name, "seed" + str(args.split_seed)) + os.makedirs(path, exist_ok=True) + + eval_out = defaultdict(list) + + vary_type = None + if isinstance(vary_param_name, list): # multiple parameters are being varied + # get parameters that are being varied over and identify whether it's a DGP/method/fi_method argument + keys, values = zip(*vary_param_vals.items()) + vary_param_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)] + vary_type = {} + for vary_param_dict in vary_param_dicts: + for param_name, param_val in vary_param_dict.items(): + if param_name in X_params_dict.keys() and param_name in y_params_dict.keys(): + raise ValueError('Cannot vary over parameter in both X and y DGPs.') + elif param_name in X_params_dict.keys(): + vary_type[param_name] = "dgp" + X_params_dict[param_name] = vary_param_vals[param_name][param_val] + elif param_name in y_params_dict.keys(): + vary_type[param_name] = "dgp" + y_params_dict[param_name] = vary_param_vals[param_name][param_val] + else: + est_kwargs = list( + itertools.chain(*[list(est.kwargs.keys()) for est in list(itertools.chain(*ests))])) + fi_est_kwargs = list( + itertools.chain(*[list(fi_est.kwargs.keys()) for fi_est in list(itertools.chain(*fi_ests))])) + if param_name in est_kwargs: + vary_type[param_name] = "est" + elif param_name in fi_est_kwargs: + vary_type[param_name] = "fi_est" + else: + raise ValueError('Invalid vary_param_name.') + + if args.parallel: + futures = [ + dask.delayed(run_simulation)(i, path, "_".join(vary_param_dict.values()), X_params_dict, X_dgp, + y_params_dict, y_dgp, ests, fi_ests, metrics, args) for i in + range(args.nreps)] + results = dask.compute(*futures) + else: + results = [ + run_simulation(i, path, "_".join(vary_param_dict.values()), X_params_dict, X_dgp, y_params_dict, + y_dgp, ests, fi_ests, metrics, args) for i in range(args.nreps)] + assert all(results) + + else: # only on parameter is being varied over + # get parameter that is being varied over and identify whether it's a DGP/method/fi_method argument + for val_name, val in vary_param_vals.items(): + if vary_param_name in X_params_dict.keys() and vary_param_name in y_params_dict.keys(): + raise ValueError('Cannot vary over parameter in both X and y DGPs.') + elif vary_param_name in X_params_dict.keys(): + vary_type = "dgp" + X_params_dict[vary_param_name] = val + elif vary_param_name in y_params_dict.keys(): + vary_type = "dgp" + y_params_dict[vary_param_name] = val + else: + est_kwargs = list(itertools.chain(*[list(est.kwargs.keys()) for est in list(itertools.chain(*ests))])) + fi_est_kwargs = list( + itertools.chain(*[list(fi_est.kwargs.keys()) for fi_est in list(itertools.chain(*fi_ests))])) + if vary_param_name in est_kwargs: + vary_type = "est" + elif vary_param_name in fi_est_kwargs: + vary_type = "fi_est" + else: + raise ValueError('Invalid vary_param_name.') + + if args.parallel: + futures = [ + dask.delayed(run_simulation)(i, path, val_name, X_params_dict, X_dgp, y_params_dict, y_dgp, ests, + fi_ests, metrics, args) for i in range(args.nreps)] + results = dask.compute(*futures) + else: + results = [run_simulation(i, path, val_name, 
X_params_dict, X_dgp, y_params_dict, y_dgp, ests, fi_ests, + metrics, args) for i in range(args.nreps)] + assert all(results) + + print('completed all experiments successfully!') + + # get model file names + model_comparison_files_all = [] + for est in ests: + estimator_name = est[0].name.split(' - ')[0] + fi_estimators_all = [fi_estimator for fi_estimator in itertools.chain(*fi_ests) \ + if fi_estimator.model_type in est[0].model_type] + model_comparison_files = [f'{estimator_name}_{fi_estimator.name}_comparisons.pkl' for fi_estimator in + fi_estimators_all] + model_comparison_files_all += model_comparison_files + + # aggregate results + results_list = [] + if isinstance(vary_param_name, list): + for vary_param_dict in vary_param_dicts: + val_name = "_".join(vary_param_dict.values()) + + for i in range(args.nreps): + all_files = glob.glob(oj(path, val_name, 'rep' + str(i), '*')) + model_files = sorted([f for f in all_files if os.path.basename(f) in model_comparison_files_all]) + + if len(model_files) == 0: + print('No files found at ', oj(path, val_name, 'rep' + str(i))) + continue + + results = pd.concat( + [pkl.load(open(f, 'rb'))['df'] for f in model_files], + axis=0 + ) + + for param_name, param_val in vary_param_dict.items(): + val = vary_param_vals[param_name][param_val] + if vary_type[param_name] == "dgp": + if np.isscalar(val): + results.insert(0, param_name, val) + else: + results.insert(0, param_name, [val for i in range(results.shape[0])]) + results.insert(1, param_name + "_name", param_val) + elif vary_type[param_name] == "est" or vary_type[param_name] == "fi_est": + results.insert(0, param_name + "_name", copy.deepcopy(results[param_name])) + results.insert(0, 'rep', i) + results_list.append(results) + else: + for val_name, val in vary_param_vals.items(): + for i in range(args.nreps): + all_files = glob.glob(oj(path, val_name, 'rep' + str(i), '*')) + model_files = sorted([f for f in all_files if os.path.basename(f) in model_comparison_files_all]) + + if len(model_files) == 0: + print('No files found at ', oj(path, val_name, 'rep' + str(i))) + continue + + results = pd.concat( + [pkl.load(open(f, 'rb'))['df'] for f in model_files], + axis=0 + ) + if vary_type == "dgp": + if np.isscalar(val): + results.insert(0, vary_param_name, val) + else: + results.insert(0, vary_param_name, [val for i in range(results.shape[0])]) + results.insert(1, vary_param_name + "_name", val_name) + results.insert(2, 'rep', i) + elif vary_type == "est" or vary_type == "fi_est": + results.insert(0, vary_param_name + "_name", copy.deepcopy(results[vary_param_name])) + results.insert(1, 'rep', i) + results_list.append(results) + results_merged = pd.concat(results_list, axis=0) + pkl.dump(results_merged, open(oj(path, 'results.pkl'), 'wb')) + results_df = reformat_results(results_merged) + results_df.to_csv(oj(path, 'results.csv'), index=False) + + print('merged and saved all experiment results successfully!') + + # create R markdown summary of results + if args.create_rmd: + if args.show_vars is None: + show_vars = 'NULL' + else: + show_vars = args.show_vars + + if isinstance(vary_param_name, list): + vary_param_name = "; ".join(vary_param_name) + + sim_rmd = os.path.basename(results_dir) + '_simulation_results.Rmd' + os.system( + 'cp {} \'{}\''.format(oj("rmd", "simulation_results.Rmd"), sim_rmd) + ) + os.system( + 'Rscript -e "rmarkdown::render(\'{}\', params = list(results_dir = \'{}\', vary_param_name = \'{}\', seed = {}, keep_vars = {}), output_file = \'{}\', quiet = TRUE)"'.format( + sim_rmd, + 
results_dir, vary_param_name, str(args.split_seed), str(show_vars), + oj(path, "simulation_results.html")) + ) + os.system('rm \'{}\''.format(sim_rmd)) + print("created rmd of simulation results successfully!") + +# %% diff --git a/feature_importance/02_run_importance_real_data.py b/feature_importance/02_run_importance_real_data.py new file mode 100644 index 0000000..4a51745 --- /dev/null +++ b/feature_importance/02_run_importance_real_data.py @@ -0,0 +1,402 @@ +# Example usage: run in command line +# cd feature_importance/ +# python 01_run_simulations.py --nreps 2 --config test --split_seed 12345 --ignore_cache +# python 01_run_simulations.py --nreps 2 --config test --split_seed 12345 --ignore_cache --create_rmd + +import copy +import os +from os.path import join as oj +import glob +import argparse +import pickle as pkl +import time +import warnings +from scipy import stats +import dask +from dask.distributed import Client +import numpy as np +import pandas as pd +from tqdm import tqdm +import sys +from collections import defaultdict +from typing import Callable, List, Tuple +import itertools +from functools import partial + +sys.path.append(".") +sys.path.append("..") +sys.path.append("../..") +import fi_config +from util import ModelConfig, FIModelConfig, apply_splitting_strategy, auroc_score, auprc_score + +from sklearn.metrics import accuracy_score, f1_score, recall_score, \ + precision_score, average_precision_score, r2_score, explained_variance_score, \ + mean_squared_error, mean_absolute_error, log_loss + +warnings.filterwarnings("ignore", message="Bins whose width") + + +def compare_estimators(estimators: List[ModelConfig], + fi_estimators: List[FIModelConfig], + X, y, args, ) -> Tuple[dict, dict]: + """Calculates results given estimators, feature importance estimators, and datasets. 
+ Called in run_comparison + """ + if type(estimators) != list: + raise Exception("First argument needs to be a list of Models") + + # initialize results + results = defaultdict(lambda: []) + + # loop over model estimators + for model in tqdm(estimators, leave=False): + est = model.cls(**model.kwargs) + + # get kwargs for all fi_ests + fi_kwargs = {} + for fi_est in fi_estimators: + fi_kwargs.update(fi_est.kwargs) + + # get groups of estimators for each splitting strategy + fi_ests_dict = defaultdict(list) + for fi_est in fi_estimators: + fi_ests_dict[fi_est.splitting_strategy].append(fi_est) + + # loop over splitting strategies + for splitting_strategy, fi_ests in fi_ests_dict.items(): + # implement provided splitting strategy + if splitting_strategy is not None: + X_train, X_tune, X_test, y_train, y_tune, y_test = apply_splitting_strategy(X, y, splitting_strategy, args.split_seed) + if splitting_strategy == "train-test-prediction": + X_test_pred = copy.deepcopy(X_test) + y_test_pred = copy.deepcopy(y_test) + X_test = X_train + y_test = y_train + else: + X_train = X + X_tune = X + X_test = X + y_train = y + y_tune = y + y_test = y + + # fit model + est.fit(X_train, y_train) + + # loop over fi estimators + for fi_est in fi_ests: + metric_results = { + 'model': model.name, + 'fi': fi_est.name, + 'splitting_strategy': splitting_strategy + } + start = time.time() + fi_score = fi_est.cls(X_test, y_test, copy.deepcopy(est), **fi_est.kwargs) + end = time.time() + metric_results['fi_scores'] = copy.deepcopy(fi_score) + metric_results['time'] = end - start + + if splitting_strategy == "train-test-prediction" and args.eval_top_ks is not None: + fi_rankings = fi_score.sort_values("importance", ascending=not fi_est.ascending) + pred_est = copy.deepcopy(est) + metrics = get_metrics(args.mode) + if args.eval_top_ks == "auto": + eval_top_ks = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 15, + 20, 25, 30, 35, 40, 45, 50, 60, 70, 80, 90, 100] + else: + eval_top_ks = [int(k.strip()) for k in args.eval_top_ks.split(",")] + pred_metrics_out = [] + for k in eval_top_ks: + top_k_features = fi_rankings["var"][:k] + Xk_train = X_train[:, top_k_features] + Xk_test_pred = X_test_pred[:, top_k_features] + pred_est.fit(Xk_train, y_train) + y_pred = pred_est.predict(Xk_test_pred) + if args.mode != 'regression': + y_pred_proba = pred_est.predict_proba(Xk_test_pred) + if args.mode == 'binary_classification': + y_pred_proba = y_pred_proba[:, 1] + else: + y_pred_proba = y_pred + for met_name, met in metrics: + if met is not None: + if args.mode == 'regression' \ + or met_name in ['accuracy', 'f1', 'precision', 'recall']: + pred_metrics_out.append({ + "k": k, "metric": met_name, "metric_value": met(y_test_pred, y_pred) + }) + else: + pred_metrics_out.append({ + "k": k, "metric": met_name, "metric_value": met(y_test_pred, y_pred_proba) + }) + metric_results["pred_metrics"] = copy.deepcopy(pd.DataFrame(pred_metrics_out)) + + # initialize results with metadata and results + kwargs: dict = model.kwargs # dict + for k in kwargs: + results[k].append(kwargs[k]) + for k in fi_kwargs: + if k in fi_est.kwargs: + results[k].append(str(fi_est.kwargs[k])) + else: + results[k].append(None) + for met_name, met_val in metric_results.items(): + results[met_name].append(met_val) + return results + + +def run_comparison(path: str, + X, y, + estimators: List[ModelConfig], + fi_estimators: List[FIModelConfig], + args): + estimator_name = estimators[0].name.split(' - ')[0] + fi_estimators_all = [fi_estimator for fi_estimator in 
itertools.chain(*fi_estimators) \ + if fi_estimator.model_type in estimators[0].model_type] + model_comparison_files_all = [oj(path, f'{estimator_name}_{fi_estimator.name}_comparisons.pkl') \ + for fi_estimator in fi_estimators_all] + if args.parallel_id is not None: + model_comparison_files_all = [f'_{args.parallel_id[0]}.'.join(model_comparison_file.split('.')) \ + for model_comparison_file in model_comparison_files_all] + + fi_estimators = [] + model_comparison_files = [] + for model_comparison_file, fi_estimator in zip(model_comparison_files_all, fi_estimators_all): + if os.path.isfile(model_comparison_file) and not args.ignore_cache: + print( + f'{estimator_name} with {fi_estimator.name} results already computed and cached. use --ignore_cache to recompute') + else: + fi_estimators.append(fi_estimator) + model_comparison_files.append(model_comparison_file) + + if len(fi_estimators) == 0: + return + + results = compare_estimators(estimators=estimators, + fi_estimators=fi_estimators, + X=X, y=y, + args=args) + + estimators_list = [e.name for e in estimators] + + df = pd.DataFrame.from_dict(results) + df['split_seed'] = args.split_seed + if args.nosave_cols is not None: + nosave_cols = np.unique([x.strip() for x in args.nosave_cols.split(",")]) + else: + nosave_cols = [] + for col in nosave_cols: + if col in df.columns: + df = df.drop(columns=[col]) + + for model_comparison_file, fi_estimator in zip(model_comparison_files, fi_estimators): + output_dict = { + # metadata + 'sim_name': args.config, + 'estimators': estimators_list, + 'fi_estimators': fi_estimator.name, + + # actual values + 'df': df.loc[df.fi == fi_estimator.name], + } + pkl.dump(output_dict, open(model_comparison_file, 'wb')) + return df + + +def reformat_results(results): + results = results.reset_index().drop(columns=['index']) + fi_scores = pd.concat(results.pop('fi_scores').to_dict()). \ + reset_index(level=0).rename(columns={'level_0': 'index'}) + if "pred_metrics" in results.columns: + pred_metrics = pd.concat(results.pop('pred_metrics').to_dict()). 
\ + reset_index(level=0).rename(columns={'level_0': 'index'}) + else: + pred_metrics = None + results_df = pd.merge(results, fi_scores, left_index=True, right_on="index") + if pred_metrics is not None: + pred_results_df = pd.merge(results, pred_metrics, left_index=True, right_on="index") + else: + pred_results_df = None + return results_df, pred_results_df + + +def get_metrics(mode: str = 'regression'): + if mode == 'binary_classification': + return [ + ('rocauc', auroc_score), + ('prauc', auprc_score), + ('logloss', log_loss), + ('accuracy', accuracy_score), + ('f1', f1_score), + ('recall', recall_score), + ('precision', precision_score), + ('avg_precision', average_precision_score) + ] + elif mode == 'multiclass_classification': + return [ + ('rocauc', partial(auroc_score, multi_class="ovr")), + ('prauc', partial(auprc_score, multi_class="ovr")), + ('logloss', log_loss), + ('accuracy', accuracy_score), + ('f1', partial(f1_score, average='micro')), + ('recall', partial(recall_score, average='micro')), + ('precision', partial(precision_score, average='micro')) + ] + elif mode == 'regression': + return [ + ('r2', r2_score), + ('explained_variance', explained_variance_score), + ('mean_squared_error', mean_squared_error), + ('mean_absolute_error', mean_absolute_error), + ] + + +def run_simulation(i, path, Xpath, ypath, ests, fi_ests, args): + X_df = pd.read_csv(Xpath) + y_df = pd.read_csv(ypath) + if args.response_idx is None: + keep_cols = y_df.columns + else: + keep_cols = [args.response_idx] + for col in keep_cols: + y = y_df[col].to_numpy().ravel() + keep_idx = ~pd.isnull(y) + X = X_df[keep_idx].to_numpy() + y = y[keep_idx] + if y_df.shape[1] > 1: + output_path = oj(path, col) + else: + output_path = path + os.makedirs(oj(output_path, "rep" + str(i)), exist_ok=True) + for est in ests: + for idx in range(len(est)): + if "random_state" in est[idx].kwargs.keys(): + est[idx].kwargs["random_state"] = i + results = run_comparison( + path=oj(output_path, "rep" + str(i)), + X=X, y=y, + estimators=est, + fi_estimators=fi_ests, + args=args + ) + + return True + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + + default_dir = os.getenv("SCRATCH") + if default_dir is not None: + default_dir = oj(default_dir, "feature_importance", "results") + else: + default_dir = oj(os.path.dirname(os.path.realpath(__file__)), 'results') + + parser.add_argument('--nreps', type=int, default=1) + parser.add_argument('--mode', type=str, default='binary_classification') + parser.add_argument('--model', type=str, default=None) + parser.add_argument('--fi_model', type=str, default=None) + parser.add_argument('--config', type=str, default='test_real_data') + parser.add_argument('--response_idx', type=str, default=None) + parser.add_argument('--nosave_cols', type=str, default=None) + parser.add_argument('--eval_top_ks', type=str, default="auto") + + # for multiple reruns, should support varying split_seed + parser.add_argument('--ignore_cache', action='store_true', default=False) + parser.add_argument('--verbose', action='store_true', default=True) + parser.add_argument('--parallel', action='store_true', default=False) + parser.add_argument('--parallel_id', nargs='+', type=int, default=None) + parser.add_argument('--n_cores', type=int, default=None) + parser.add_argument('--split_seed', type=int, default=0) + parser.add_argument('--results_path', type=str, default=default_dir) + + args = parser.parse_args() + + assert args.mode in {'regression', 'binary_classification', 'multiclass_classification'} + + 
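+    # Illustrative invocation (a sketch only; "test_real_data" is this script's default
+    # --config value and the remaining flags are defined by the argument parser above):
+    #   python 02_run_importance_real_data.py --config test_real_data --mode binary_classification \
+    #       --split_seed 0 --eval_top_ks 5,10,25 --ignore_cache
+    # Note: --eval_top_ks accepts "auto" or a comma-separated list of integers k, and the
+    # top-k prediction evaluation only runs for FI estimators whose splitting_strategy is
+    # "train-test-prediction".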
if args.parallel: + if args.n_cores is None: + print(os.getenv("SLURM_CPUS_ON_NODE")) + n_cores = int(os.getenv("SLURM_CPUS_ON_NODE")) + else: + n_cores = args.n_cores + client = Client(n_workers=n_cores) + + ests, fi_ests, Xpath, ypath = fi_config.get_fi_configs(args.config, real_data=True) + + if args.model: + ests = list(filter(lambda x: args.model.lower() == x[0].name.lower(), ests)) + if args.fi_model: + fi_ests = list(filter(lambda x: args.fi_model.lower() == x[0].name.lower(), fi_ests)) + + if len(ests) == 0: + raise ValueError('No valid estimators', 'sim', args.config, 'models', args.model, 'fi', args.fi_model) + if len(fi_ests) == 0: + raise ValueError('No valid FI estimators', 'sim', args.config, 'models', args.model, 'fi', args.fi_model) + if args.verbose: + print('running', args.config, + 'ests', ests, + 'fi_ests', fi_ests) + print('\tsaving to', args.results_path) + + results_dir = oj(args.results_path, args.config) + path = oj(results_dir, "seed" + str(args.split_seed)) + os.makedirs(path, exist_ok=True) + + eval_out = defaultdict(list) + + if args.parallel: + futures = [dask.delayed(run_simulation)(i, path, Xpath, ypath, ests, fi_ests, args) for i in range(args.nreps)] + results = dask.compute(*futures) + else: + results = [run_simulation(i, path, Xpath, ypath, ests, fi_ests, args) for i in range(args.nreps)] + assert all(results) + + print('completed all experiments successfully!') + + # get model file names + model_comparison_files_all = [] + for est in ests: + estimator_name = est[0].name.split(' - ')[0] + fi_estimators_all = [fi_estimator for fi_estimator in itertools.chain(*fi_ests) \ + if fi_estimator.model_type in est[0].model_type] + model_comparison_files = [f'{estimator_name}_{fi_estimator.name}_comparisons.pkl' for fi_estimator in + fi_estimators_all] + model_comparison_files_all += model_comparison_files + + # aggregate results + y_df = pd.read_csv(ypath) + results_list = [] + for col in y_df.columns: + if y_df.shape[1] > 1: + output_path = oj(path, col) + else: + output_path = path + for i in range(args.nreps): + all_files = glob.glob(oj(output_path, 'rep' + str(i), '*')) + model_files = sorted([f for f in all_files if os.path.basename(f) in model_comparison_files_all]) + + if len(model_files) == 0: + print('No files found at ', oj(output_path, 'rep' + str(i))) + continue + + results = pd.concat( + [pkl.load(open(f, 'rb'))['df'] for f in model_files], + axis=0 + ) + results.insert(0, 'rep', i) + if y_df.shape[1] > 1: + results.insert(1, 'y_task', col) + results_list.append(results) + + results_merged = pd.concat(results_list, axis=0) + pkl.dump(results_merged, open(oj(path, 'results.pkl'), 'wb')) + results_df, pred_results_df = reformat_results(results_merged) + results_df.to_csv(oj(path, 'results.csv'), index=False) + if pred_results_df is not None: + pred_results_df.to_csv(oj(path, 'pred_results.csv'), index=False) + + print('merged and saved all experiment results successfully!') + +# %% diff --git a/feature_importance/03_run_prediction_simulations.py b/feature_importance/03_run_prediction_simulations.py new file mode 100644 index 0000000..0cddec9 --- /dev/null +++ b/feature_importance/03_run_prediction_simulations.py @@ -0,0 +1,422 @@ +# Example usage: run in command line +# cd feature_importance/ +# python 03_run_real_data_prediction.py --nreps 2 --config test --split_seed 12345 --ignore_cache +# python 03_run_real_data_prediction.py --nreps 2 --config test --split_seed 12345 --ignore_cache --create_rmd + +import copy +import os +from os.path import join as 
oj +import glob +import argparse +import pickle as pkl +import time +import warnings +from scipy import stats +import dask +from dask.distributed import Client +import numpy as np +import pandas as pd +from tqdm import tqdm +import sys +from collections import defaultdict +from typing import Callable, List, Tuple +import itertools +from functools import partial + +sys.path.append(".") +sys.path.append("..") +sys.path.append("../..") +import fi_config +from util import ModelConfig, apply_splitting_strategy, auroc_score, auprc_score + +from sklearn.metrics import accuracy_score, f1_score, recall_score, \ + precision_score, average_precision_score, r2_score, explained_variance_score, \ + mean_squared_error, mean_absolute_error, log_loss + +warnings.filterwarnings("ignore", message="Bins whose width") + + +def compare_estimators(estimators: List[ModelConfig], + X, y, + metrics: List[Tuple[str, Callable]], + args, rep) -> Tuple[dict, dict]: + """Calculates results given estimators, feature importance estimators, and datasets. + Called in run_comparison + """ + if type(estimators) != list: + raise Exception("First argument needs to be a list of Models") + if type(metrics) != list: + raise Exception("Argument metrics needs to be a list containing ('name', callable) pairs") + + # initialize results + results = defaultdict(lambda: []) + + if args.splitting_strategy is not None: + X_train, X_tune, X_test, y_train, y_tune, y_test = apply_splitting_strategy( + X, y, args.splitting_strategy, args.split_seed + rep) + else: + X_train = X + X_tune = X + X_test = X + y_train = y + y_tune = y + y_test = y + + # loop over model estimators + for model in tqdm(estimators, leave=False): + est = model.cls(**model.kwargs) + + start = time.time() + est.fit(X_train, y_train) + end = time.time() + + metric_results = {'model': model.name} + y_pred = est.predict(X_test) + if args.mode != 'regression': + y_pred_proba = est.predict_proba(X_test) + if args.mode == 'binary_classification': + y_pred_proba = y_pred_proba[:, 1] + else: + y_pred_proba = y_pred + for met_name, met in metrics: + if met is not None: + if args.mode == 'regression' \ + or met_name in ['accuracy', 'f1', 'precision', 'recall']: + metric_results[met_name] = met(y_test, y_pred) + else: + metric_results[met_name] = met(y_test, y_pred_proba) + metric_results['predictions'] = copy.deepcopy(pd.DataFrame(y_pred_proba)) + metric_results['time'] = end - start + + # initialize results with metadata and metric results + kwargs: dict = model.kwargs # dict + for k in kwargs: + results[k].append(kwargs[k]) + for met_name, met_val in metric_results.items(): + results[met_name].append(met_val) + + return results + + +def run_comparison(rep: int, + path: str, + X, y, + metrics: List[Tuple[str, Callable]], + estimators: List[ModelConfig], + args): + estimator_name = estimators[0].name.split(' - ')[0] + model_comparison_file = oj(path, f'{estimator_name}_comparisons.pkl') + if args.parallel_id is not None: + model_comparison_file = f'_{args.parallel_id[0]}.'.join(model_comparison_file.split('.')) + + if os.path.isfile(model_comparison_file) and not args.ignore_cache: + print(f'{estimator_name} results already computed and cached. 
use --ignore_cache to recompute') + return + + results = compare_estimators(estimators=estimators, + X=X, y=y, + metrics=metrics, + args=args, + rep=rep) + + estimators_list = [e.name for e in estimators] + + df = pd.DataFrame.from_dict(results) + df['split_seed'] = args.split_seed + rep + if args.nosave_cols is not None: + nosave_cols = np.unique([x.strip() for x in args.nosave_cols.split(",")]) + else: + nosave_cols = [] + for col in nosave_cols: + if col in df.columns: + df = df.drop(columns=[col]) + + output_dict = { + # metadata + 'sim_name': args.config, + 'estimators': estimators_list, + + # actual values + 'df': df, + } + pkl.dump(output_dict, open(model_comparison_file, 'wb')) + + return df + + +def reformat_results(results): + results = results.reset_index().drop(columns=['index']) + predictions = pd.concat(results.pop('predictions').to_dict()). \ + reset_index(level=[0, 1]).rename(columns={'level_0': 'index', 'level_1': 'sample_id'}) + results_df = pd.merge(results, predictions, left_index=True, right_on="index") + return results_df + + +def get_metrics(mode: str = 'regression'): + if mode == 'binary_classification': + return [ + ('rocauc', auroc_score), + ('prauc', auprc_score), + ('logloss', log_loss), + ('accuracy', accuracy_score), + ('f1', f1_score), + ('recall', recall_score), + ('precision', precision_score), + ('avg_precision', average_precision_score) + ] + elif mode == 'multiclass_classification': + return [ + ('rocauc', partial(auroc_score, multi_class="ovr")), + ('prauc', partial(auprc_score, multi_class="ovr")), + ('logloss', log_loss), + ('accuracy', accuracy_score), + ('f1', partial(f1_score, average='micro')), + ('recall', partial(recall_score, average='micro')), + ('precision', partial(precision_score, average='micro')) + ] + elif mode == 'regression': + return [ + ('r2', r2_score), + ('explained_variance', explained_variance_score), + ('mean_squared_error', mean_squared_error), + ('mean_absolute_error', mean_absolute_error), + ] + + +def run_simulation(i, path, val_name, X_params_dict, X_dgp, y_params_dict, y_dgp, ests, metrics, args): + os.makedirs(oj(path, val_name, "rep" + str(i)), exist_ok=True) + np.random.seed(i) + max_iter = 100 + iter = 0 + while iter <= max_iter: # regenerate data if y is constant + X = X_dgp(**X_params_dict) + y, support, beta = y_dgp(X, **y_params_dict, return_support=True) + if not all(y == y[0]): + break + iter += 1 + if iter > max_iter: + raise ValueError("Response y is constant.") + if args.omit_vars is not None: + omit_vars = np.unique([int(x.strip()) for x in args.omit_vars.split(",")]) + support = np.delete(support, omit_vars) + X = np.delete(X, omit_vars, axis=1) + del beta # note: beta is not currently supported when using omit_vars + + for est in ests: + results = run_comparison(rep=i, + path=oj(path, val_name, "rep" + str(i)), + X=X, y=y, + metrics=metrics, + estimators=est, + args=args) + return True + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + + default_dir = os.getenv("SCRATCH") + if default_dir is not None: + default_dir = oj(default_dir, "feature_importance", "results") + else: + default_dir = oj(os.path.dirname(os.path.realpath(__file__)), 'results') + + parser.add_argument('--nreps', type=int, default=2) + parser.add_argument('--mode', type=str, default='regression') + parser.add_argument('--model', type=str, default=None) + parser.add_argument('--config', type=str, default='mdi_plus.prediction_sims.ccle_rnaseq_regression-') + parser.add_argument('--omit_vars', type=str, default=None) # 
comma-separated string of variables to omit + parser.add_argument('--nosave_cols', type=str, default=None) + + # for multiple reruns, should support varying split_seed + parser.add_argument('--ignore_cache', action='store_true', default=False) + parser.add_argument('--splitting_strategy', type=str, default="train-test") + parser.add_argument('--verbose', action='store_true', default=True) + parser.add_argument('--parallel', action='store_true', default=False) + parser.add_argument('--parallel_id', nargs='+', type=int, default=None) + parser.add_argument('--n_cores', type=int, default=None) + parser.add_argument('--split_seed', type=int, default=0) + parser.add_argument('--results_path', type=str, default=default_dir) + + args = parser.parse_args() + + assert args.splitting_strategy in { + 'train-test', 'train-tune-test', 'train-test-lowdata', 'train-tune-test-lowdata'} + assert args.mode in {'regression', 'binary_classification', 'multiclass_classification'} + + if args.parallel: + if args.n_cores is None: + print(os.getenv("SLURM_CPUS_ON_NODE")) + n_cores = int(os.getenv("SLURM_CPUS_ON_NODE")) + else: + n_cores = args.n_cores + client = Client(n_workers=n_cores) + + ests, fi_ests, \ + X_dgp, X_params_dict, y_dgp, y_params_dict, \ + vary_param_name, vary_param_vals = fi_config.get_fi_configs(args.config) + metrics = get_metrics(args.mode) + + if args.model: + ests = list(filter(lambda x: args.model.lower() == x[0].name.lower(), ests)) + + if len(ests) == 0: + raise ValueError('No valid estimators', 'sim', args.config, 'models', args.model) + if args.verbose: + print('running', args.config, + 'ests', ests) + print('\tsaving to', args.results_path) + + if args.omit_vars is not None: + results_dir = oj(args.results_path, args.config + "_omitted_vars") + else: + results_dir = oj(args.results_path, args.config) + + if isinstance(vary_param_name, list): + path = oj(results_dir, "varying_" + "_".join(vary_param_name), "seed" + str(args.split_seed)) + else: + path = oj(results_dir, "varying_" + vary_param_name, "seed" + str(args.split_seed)) + os.makedirs(path, exist_ok=True) + + eval_out = defaultdict(list) + + vary_type = None + if isinstance(vary_param_name, list): # multiple parameters are being varied + # get parameters that are being varied over and identify whether it's a DGP/method/fi_method argument + keys, values = zip(*vary_param_vals.items()) + vary_param_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)] + vary_type = {} + for vary_param_dict in vary_param_dicts: + for param_name, param_val in vary_param_dict.items(): + if param_name in X_params_dict.keys() and param_name in y_params_dict.keys(): + raise ValueError('Cannot vary over parameter in both X and y DGPs.') + elif param_name in X_params_dict.keys(): + vary_type[param_name] = "dgp" + X_params_dict[param_name] = vary_param_vals[param_name][param_val] + elif param_name in y_params_dict.keys(): + vary_type[param_name] = "dgp" + y_params_dict[param_name] = vary_param_vals[param_name][param_val] + else: + est_kwargs = list( + itertools.chain(*[list(est.kwargs.keys()) for est in list(itertools.chain(*ests))])) + if param_name in est_kwargs: + vary_type[param_name] = "est" + else: + raise ValueError('Invalid vary_param_name.') + + if args.parallel: + futures = [ + dask.delayed(run_simulation)(i, path, "_".join(vary_param_dict.values()), X_params_dict, X_dgp, + y_params_dict, y_dgp, ests, metrics, args) for i in + range(args.nreps)] + results = dask.compute(*futures) + else: + results = [ + run_simulation(i, path, 
"_".join(vary_param_dict.values()), X_params_dict, X_dgp, y_params_dict, + y_dgp, ests, metrics, args) for i in range(args.nreps)] + assert all(results) + + else: # only on parameter is being varied over + # get parameter that is being varied over and identify whether it's a DGP/method/fi_method argument + for val_name, val in vary_param_vals.items(): + if vary_param_name in X_params_dict.keys() and vary_param_name in y_params_dict.keys(): + raise ValueError('Cannot vary over parameter in both X and y DGPs.') + elif vary_param_name in X_params_dict.keys(): + vary_type = "dgp" + X_params_dict[vary_param_name] = val + elif vary_param_name in y_params_dict.keys(): + vary_type = "dgp" + y_params_dict[vary_param_name] = val + else: + est_kwargs = list(itertools.chain(*[list(est.kwargs.keys()) for est in list(itertools.chain(*ests))])) + if vary_param_name in est_kwargs: + vary_type = "est" + else: + raise ValueError('Invalid vary_param_name.') + + if args.parallel: + futures = [ + dask.delayed(run_simulation)(i, path, val_name, X_params_dict, X_dgp, y_params_dict, y_dgp, ests, + metrics, args) for i in range(args.nreps)] + results = dask.compute(*futures) + else: + results = [run_simulation(i, path, val_name, X_params_dict, X_dgp, y_params_dict, y_dgp, ests, + metrics, args) for i in range(args.nreps)] + assert all(results) + + print('completed all experiments successfully!') + + # get model file names + model_comparison_files_all = [] + for est in ests: + estimator_name = est[0].name.split(' - ')[0] + model_comparison_file = f'{estimator_name}_comparisons.pkl' + model_comparison_files_all.append(model_comparison_file) + + # aggregate results + # aggregate results + results_list = [] + if isinstance(vary_param_name, list): + for vary_param_dict in vary_param_dicts: + val_name = "_".join(vary_param_dict.values()) + + for i in range(args.nreps): + all_files = glob.glob(oj(path, val_name, 'rep' + str(i), '*')) + model_files = sorted([f for f in all_files if os.path.basename(f) in model_comparison_files_all]) + + if len(model_files) == 0: + print('No files found at ', oj(path, val_name, 'rep' + str(i))) + continue + + results = pd.concat( + [pkl.load(open(f, 'rb'))['df'] for f in model_files], + axis=0 + ) + + for param_name, param_val in vary_param_dict.items(): + val = vary_param_vals[param_name][param_val] + if vary_type[param_name] == "dgp": + if np.isscalar(val): + results.insert(0, param_name, val) + else: + results.insert(0, param_name, [val for i in range(results.shape[0])]) + results.insert(1, param_name + "_name", param_val) + elif vary_type[param_name] == "est": + results.insert(0, param_name + "_name", copy.deepcopy(results[param_name])) + results.insert(0, 'rep', i) + results_list.append(results) + else: + for val_name, val in vary_param_vals.items(): + for i in range(args.nreps): + all_files = glob.glob(oj(path, val_name, 'rep' + str(i), '*')) + model_files = sorted([f for f in all_files if os.path.basename(f) in model_comparison_files_all]) + + if len(model_files) == 0: + print('No files found at ', oj(path, val_name, 'rep' + str(i))) + continue + + results = pd.concat( + [pkl.load(open(f, 'rb'))['df'] for f in model_files], + axis=0 + ) + if vary_type == "dgp": + if np.isscalar(val): + results.insert(0, vary_param_name, val) + else: + results.insert(0, vary_param_name, [val for i in range(results.shape[0])]) + results.insert(1, vary_param_name + "_name", val_name) + results.insert(2, 'rep', i) + elif vary_type == "est": + results.insert(0, vary_param_name + "_name", 
copy.deepcopy(results[vary_param_name])) + results.insert(1, 'rep', i) + results_list.append(results) + + results_merged = pd.concat(results_list, axis=0) + pkl.dump(results_merged, open(oj(path, 'results.pkl'), 'wb')) + results_df = reformat_results(results_merged) + results_df.to_csv(oj(path, 'results.csv'), index=False) + + print('merged and saved all experiment results successfully!') + +# %% diff --git a/feature_importance/04_run_prediction_real_data.py b/feature_importance/04_run_prediction_real_data.py new file mode 100644 index 0000000..bb41e3c --- /dev/null +++ b/feature_importance/04_run_prediction_real_data.py @@ -0,0 +1,333 @@ +# Example usage: run in command line +# cd feature_importance/ +# python 03_run_real_data_prediction.py --nreps 2 --config test --split_seed 12345 --ignore_cache +# python 03_run_real_data_prediction.py --nreps 2 --config test --split_seed 12345 --ignore_cache --create_rmd + +import copy +import os +from os.path import join as oj +import glob +import argparse +import pickle as pkl +import time +import warnings +from scipy import stats +import dask +from dask.distributed import Client +import numpy as np +import pandas as pd +from tqdm import tqdm +import sys +from collections import defaultdict +from typing import Callable, List, Tuple +import itertools +from functools import partial + +sys.path.append(".") +sys.path.append("..") +sys.path.append("../..") +import fi_config +from util import ModelConfig, apply_splitting_strategy, auroc_score, auprc_score + +from sklearn.metrics import accuracy_score, f1_score, recall_score, \ + precision_score, average_precision_score, r2_score, explained_variance_score, \ + mean_squared_error, mean_absolute_error, log_loss + +warnings.filterwarnings("ignore", message="Bins whose width") + + +def compare_estimators(estimators: List[ModelConfig], + X, y, + metrics: List[Tuple[str, Callable]], + args, rep) -> Tuple[dict, dict]: + """Calculates results given estimators, feature importance estimators, and datasets. 
+ Called in run_comparison + """ + if type(estimators) != list: + raise Exception("First argument needs to be a list of Models") + if type(metrics) != list: + raise Exception("Argument metrics needs to be a list containing ('name', callable) pairs") + + # initialize results + results = defaultdict(lambda: []) + + if args.splitting_strategy is not None: + X_train, X_tune, X_test, y_train, y_tune, y_test = apply_splitting_strategy( + X, y, args.splitting_strategy, args.split_seed + rep) + else: + X_train = X + X_tune = X + X_test = X + y_train = y + y_tune = y + y_test = y + + # loop over model estimators + for model in tqdm(estimators, leave=False): + est = model.cls(**model.kwargs) + + start = time.time() + est.fit(X_train, y_train) + end = time.time() + + metric_results = {'model': model.name} + y_pred = est.predict(X_test) + if args.mode != 'regression': + y_pred_proba = est.predict_proba(X_test) + if args.mode == 'binary_classification': + y_pred_proba = y_pred_proba[:, 1] + else: + y_pred_proba = y_pred + for met_name, met in metrics: + if met is not None: + if args.mode == 'regression' \ + or met_name in ['accuracy', 'f1', 'precision', 'recall']: + metric_results[met_name] = met(y_test, y_pred) + else: + metric_results[met_name] = met(y_test, y_pred_proba) + metric_results['predictions'] = copy.deepcopy(pd.DataFrame(y_pred_proba)) + metric_results['time'] = end - start + + # initialize results with metadata and metric results + kwargs: dict = model.kwargs # dict + for k in kwargs: + results[k].append(kwargs[k]) + for met_name, met_val in metric_results.items(): + results[met_name].append(met_val) + + return results + + +def run_comparison(rep: int, + path: str, + X, y, + metrics: List[Tuple[str, Callable]], + estimators: List[ModelConfig], + args): + estimator_name = estimators[0].name.split(' - ')[0] + model_comparison_file = oj(path, f'{estimator_name}_comparisons.pkl') + if args.parallel_id is not None: + model_comparison_file = f'_{args.parallel_id[0]}.'.join(model_comparison_file.split('.')) + + if os.path.isfile(model_comparison_file) and not args.ignore_cache: + print(f'{estimator_name} results already computed and cached. use --ignore_cache to recompute') + return + + results = compare_estimators(estimators=estimators, + X=X, y=y, + metrics=metrics, + args=args, + rep=rep) + + estimators_list = [e.name for e in estimators] + + df = pd.DataFrame.from_dict(results) + df['split_seed'] = args.split_seed + rep + if args.nosave_cols is not None: + nosave_cols = np.unique([x.strip() for x in args.nosave_cols.split(",")]) + else: + nosave_cols = [] + for col in nosave_cols: + if col in df.columns: + df = df.drop(columns=[col]) + + output_dict = { + # metadata + 'sim_name': args.config, + 'estimators': estimators_list, + + # actual values + 'df': df, + } + pkl.dump(output_dict, open(model_comparison_file, 'wb')) + + return df + + +def reformat_results(results): + results = results.reset_index().drop(columns=['index']) + predictions = pd.concat(results.pop('predictions').to_dict()). 
\ + reset_index(level=[0, 1]).rename(columns={'level_0': 'index', 'level_1': 'sample_id'}) + results_df = pd.merge(results, predictions, left_index=True, right_on="index") + return results_df + + +def get_metrics(mode: str = 'regression'): + if mode == 'binary_classification': + return [ + ('rocauc', auroc_score), + ('prauc', auprc_score), + ('logloss', log_loss), + ('accuracy', accuracy_score), + ('f1', f1_score), + ('recall', recall_score), + ('precision', precision_score), + ('avg_precision', average_precision_score) + ] + elif mode == 'multiclass_classification': + return [ + ('rocauc', partial(auroc_score, multi_class="ovr")), + ('prauc', partial(auprc_score, multi_class="ovr")), + ('logloss', log_loss), + ('accuracy', accuracy_score), + ('f1', partial(f1_score, average='micro')), + ('recall', partial(recall_score, average='micro')), + ('precision', partial(precision_score, average='micro')) + ] + elif mode == 'regression': + return [ + ('r2', r2_score), + ('explained_variance', explained_variance_score), + ('mean_squared_error', mean_squared_error), + ('mean_absolute_error', mean_absolute_error), + ] + + +def run_simulation(i, path, Xpath, ypath, ests, metrics, args): + X_df = pd.read_csv(Xpath) + y_df = pd.read_csv(ypath) + if args.subsample_n is not None: + if args.subsample_n < X_df.shape[0]: + keep_rows = np.random.choice(X_df.shape[0], args.subsample_n, replace=False) + X_df = X_df.iloc[keep_rows] + y_df = y_df.iloc[keep_rows] + if args.response_idx is None: + keep_cols = y_df.columns + else: + keep_cols = [args.response_idx] + for col in keep_cols: + y = y_df[col].to_numpy().ravel() + keep_idx = ~pd.isnull(y) + X = X_df[keep_idx].to_numpy() + y = y[keep_idx] + if y_df.shape[1] > 1: + output_path = oj(path, col) + else: + output_path = path + os.makedirs(oj(output_path, "rep" + str(i)), exist_ok=True) + for est in ests: + for idx in range(len(est)): + if "random_state" in est[idx].kwargs.keys(): + est[idx].kwargs["random_state"] = i + results = run_comparison( + rep=i, + path=oj(output_path, "rep" + str(i)), + X=X, y=y, + metrics=metrics, + estimators=est, + args=args + ) + + return True + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + + default_dir = os.getenv("SCRATCH") + if default_dir is not None: + default_dir = oj(default_dir, "feature_importance", "results") + else: + default_dir = oj(os.path.dirname(os.path.realpath(__file__)), 'results') + + parser.add_argument('--nreps', type=int, default=2) + parser.add_argument('--mode', type=str, default='regression') + parser.add_argument('--model', type=str, default=None) + parser.add_argument('--config', type=str, default='mdi_plus.prediction_sims.ccle_rnaseq_regression-') + parser.add_argument('--response_idx', type=str, default=None) + parser.add_argument('--subsample_n', type=int, default=None) + parser.add_argument('--nosave_cols', type=str, default=None) + + # for multiple reruns, should support varying split_seed + parser.add_argument('--ignore_cache', action='store_true', default=False) + parser.add_argument('--splitting_strategy', type=str, default="train-test") + parser.add_argument('--verbose', action='store_true', default=True) + parser.add_argument('--parallel', action='store_true', default=False) + parser.add_argument('--parallel_id', nargs='+', type=int, default=None) + parser.add_argument('--n_cores', type=int, default=None) + parser.add_argument('--split_seed', type=int, default=0) + parser.add_argument('--results_path', type=str, default=default_dir) + + args = parser.parse_args() + + 
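+    # Illustrative invocation (a sketch only; the --config value below is this script's
+    # default and all flags are defined by the argument parser above):
+    #   python 04_run_prediction_real_data.py --config mdi_plus.prediction_sims.ccle_rnaseq_regression- \
+    #       --mode regression --nreps 2 --subsample_n 1000 --splitting_strategy train-test
+    # Note: --subsample_n 1000 is a hypothetical value; when set, rows of X (and y) are
+    # randomly subsampled without replacement before fitting.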
assert args.splitting_strategy in { + 'train-test', 'train-tune-test', 'train-test-lowdata', 'train-tune-test-lowdata'} + assert args.mode in {'regression', 'binary_classification', 'multiclass_classification'} + + if args.parallel: + if args.n_cores is None: + print(os.getenv("SLURM_CPUS_ON_NODE")) + n_cores = int(os.getenv("SLURM_CPUS_ON_NODE")) + else: + n_cores = args.n_cores + client = Client(n_workers=n_cores) + + ests, _, Xpath, ypath = fi_config.get_fi_configs(args.config, real_data=True) + metrics = get_metrics(args.mode) + + if args.model: + ests = list(filter(lambda x: args.model.lower() == x[0].name.lower(), ests)) + + if len(ests) == 0: + raise ValueError('No valid estimators', 'sim', args.config, 'models', args.model) + if args.verbose: + print('running', args.config, + 'ests', ests) + print('\tsaving to', args.results_path) + + results_dir = oj(args.results_path, args.config) + path = oj(results_dir, "seed" + str(args.split_seed)) + os.makedirs(path, exist_ok=True) + + eval_out = defaultdict(list) + + if args.parallel: + futures = [dask.delayed(run_simulation)(i, path, Xpath, ypath, ests, metrics, args) for i in range(args.nreps)] + results = dask.compute(*futures) + else: + results = [run_simulation(i, path, Xpath, ypath, ests, metrics, args) for i in range(args.nreps)] + assert all(results) + + print('completed all experiments successfully!') + + # get model file names + model_comparison_files_all = [] + for est in ests: + estimator_name = est[0].name.split(' - ')[0] + model_comparison_file = f'{estimator_name}_comparisons.pkl' + model_comparison_files_all.append(model_comparison_file) + + # aggregate results + y_df = pd.read_csv(ypath) + results_list = [] + for col in y_df.columns: + if y_df.shape[1] > 1: + output_path = oj(path, col) + else: + output_path = path + for i in range(args.nreps): + all_files = glob.glob(oj(output_path, 'rep' + str(i), '*')) + model_files = sorted([f for f in all_files if os.path.basename(f) in model_comparison_files_all]) + + if len(model_files) == 0: + print('No files found at ', oj(output_path, 'rep' + str(i))) + continue + + results = pd.concat( + [pkl.load(open(f, 'rb'))['df'] for f in model_files], + axis=0 + ) + results.insert(0, 'rep', i) + if y_df.shape[1] > 1: + results.insert(1, 'y_task', col) + results_list.append(results) + + results_merged = pd.concat(results_list, axis=0) + pkl.dump(results_merged, open(oj(path, 'results.pkl'), 'wb')) + results_df = reformat_results(results_merged) + results_df.to_csv(oj(path, 'results.csv'), index=False) + + print('merged and saved all experiment results successfully!') + +# %% diff --git a/feature_importance/__init__.py b/feature_importance/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/fi_config/__init__.py b/feature_importance/fi_config/__init__.py new file mode 100644 index 0000000..c4c5798 --- /dev/null +++ b/feature_importance/fi_config/__init__.py @@ -0,0 +1,13 @@ +import importlib + +def get_fi_configs(config_name, real_data=False): + if real_data: + ests = importlib.import_module(f'fi_config.{config_name}.models') + dgp = importlib.import_module(f'fi_config.{config_name}.dgp') + return ests.ESTIMATORS, ests.FI_ESTIMATORS, dgp.X_PATH, dgp.Y_PATH + else: + ests = importlib.import_module(f'fi_config.{config_name}.models') + dgp = importlib.import_module(f'fi_config.{config_name}.dgp') + return ests.ESTIMATORS, ests.FI_ESTIMATORS, \ + dgp.X_DGP, dgp.X_PARAMS_DICT, dgp.Y_DGP, dgp.Y_PARAMS_DICT, \ + dgp.VARY_PARAM_NAME, dgp.VARY_PARAM_VALS \ No 
newline at end of file diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_hier_poly_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_hier_poly_3m_2r_logistic_dgp/dgp.py new file mode 100644 index 0000000..5e9f5e6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_hier_poly_3m_2r_logistic_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = logistic_hier_model +Y_PARAMS_DICT = { + "m":3, + "r":2, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0.25": 0.25, "0.15": 0.15, "0.05": 0.05, "0": None}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_hier_poly_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_hier_poly_3m_2r_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_hier_poly_3m_2r_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_linear_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_linear_lss_3m_2r_logistic_dgp/dgp.py new file mode 100644 index 0000000..0643de9 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_linear_lss_3m_2r_logistic_dgp/dgp.py @@ -0,0 +1,23 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = logistic_partial_linear_lss_model +Y_PARAMS_DICT = { + "s":1, + "m":3, + "r":2, + "tau":0, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0.25": 0.25, "0.15": 0.15, "0.05": 0.05, "0": None}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git 
a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_linear_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_linear_lss_3m_2r_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_linear_lss_3m_2r_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_logistic_dgp/dgp.py new file mode 100644 index 0000000..babf294 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_logistic_dgp/dgp.py @@ -0,0 +1,20 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = logistic_model +Y_PARAMS_DICT = { + "s": 5, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0.25": 0.25, "0.15": 0.15, "0.05": 0.05, "0": None}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + 
[FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_lss_3m_2r_logistic_dgp/dgp.py new file mode 100644 index 0000000..5874330 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_lss_3m_2r_logistic_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = logistic_lss_model +Y_PARAMS_DICT = { + "m": 3, + "r": 2, + "tau": 0, + "beta": 2, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0.25": 0.25, "0.15": 0.15, "0.05": 0.05, "0": None}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_lss_3m_2r_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/ccle_rnaseq_lss_3m_2r_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_hier_poly_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_hier_poly_3m_2r_logistic_dgp/dgp.py new file mode 100644 index 0000000..16536f5 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_hier_poly_3m_2r_logistic_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = logistic_hier_model +Y_PARAMS_DICT = { + "m":3, + "r":2, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 
0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_hier_poly_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_hier_poly_3m_2r_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_hier_poly_3m_2r_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_linear_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_linear_lss_3m_2r_logistic_dgp/dgp.py new file mode 100644 index 0000000..ab616ed --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_linear_lss_3m_2r_logistic_dgp/dgp.py @@ -0,0 +1,23 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = logistic_partial_linear_lss_model +Y_PARAMS_DICT = { + "s":1, + "m":3, + "r":2, + "tau":0, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_linear_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_linear_lss_3m_2r_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_linear_lss_3m_2r_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + 
[FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_logistic_dgp/dgp.py new file mode 100644 index 0000000..9e43dad --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_logistic_dgp/dgp.py @@ -0,0 +1,20 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = logistic_model +Y_PARAMS_DICT = { + "s": 5, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_lss_3m_2r_logistic_dgp/dgp.py new file mode 100644 index 0000000..e3646d1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_lss_3m_2r_logistic_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = logistic_lss_model +Y_PARAMS_DICT = { + "m": 3, + "r": 2, + "tau": 0, + "beta": 2, + "frac_label_corruption": None +} + 
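+# The two settings below list the parameters varied across simulation runs;
+# the dictionary keys are display labels and the values are the parameter settings
+# (a frac_label_corruption of None is the "0", i.e. no-corruption, case).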
+VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_lss_3m_2r_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/enhancer_lss_3m_2r_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/dgp.py new file mode 100644 index 0000000..ec4e6c5 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_juvenile_cleaned.csv", + "sample_row_n": 1500, + "sample_col_n": None +} +Y_DGP = logistic_hier_model +Y_PARAMS_DICT = { + "m":3, + "r":2, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 
'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_linear_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_linear_lss_3m_2r_logistic_dgp/dgp.py new file mode 100644 index 0000000..d88c1fa --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_linear_lss_3m_2r_logistic_dgp/dgp.py @@ -0,0 +1,23 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_juvenile_cleaned.csv", + "sample_row_n": 1500, + "sample_col_n": None +} +Y_DGP = logistic_partial_linear_lss_model +Y_PARAMS_DICT = { + "s":1, + "m":3, + "r":2, + "tau":0, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_linear_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_linear_lss_3m_2r_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_linear_lss_3m_2r_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_logistic_dgp/dgp.py new file mode 100644 index 0000000..28b92fd --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_logistic_dgp/dgp.py @@ -0,0 +1,20 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": 
"data/X_juvenile_cleaned.csv", + "sample_row_n": 1500, + "sample_col_n": None +} +Y_DGP = logistic_model +Y_PARAMS_DICT = { + "s": 5, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_lss_3m_2r_logistic_dgp/dgp.py new file mode 100644 index 0000000..5688e4a --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_lss_3m_2r_logistic_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_juvenile_cleaned.csv", + "sample_row_n": 1500, + "sample_col_n": None +} +Y_DGP = logistic_lss_model +Y_PARAMS_DICT = { + "m": 3, + "r": 2, + "tau": 0, + "beta": 2, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_lss_3m_2r_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/juvenile_lss_3m_2r_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import 
RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/dgp.py new file mode 100644 index 0000000..0325f07 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": 1500, + "sample_col_n": 100 +} +Y_DGP = logistic_hier_model +Y_PARAMS_DICT = { + "m":3, + "r":2, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_linear_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_linear_lss_3m_2r_logistic_dgp/dgp.py new file mode 100644 index 0000000..3caf738 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_linear_lss_3m_2r_logistic_dgp/dgp.py @@ -0,0 +1,23 @@ 
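+# DGP config: rows/columns subsampled from the cleaned splicing matrix
+# (1500 rows, 100 columns), with responses drawn from a logistic
+# partial-linear LSS model; frac_label_corruption and sample_row_n are varied below.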
+import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": 1500, + "sample_col_n": 100 +} +Y_DGP = logistic_partial_linear_lss_model +Y_PARAMS_DICT = { + "s":1, + "m":3, + "r":2, + "tau":0, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_linear_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_linear_lss_3m_2r_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_linear_lss_3m_2r_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_logistic_dgp/dgp.py new file mode 100644 index 0000000..5db817e --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_logistic_dgp/dgp.py @@ -0,0 +1,20 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": 1500, + "sample_col_n": 100 +} +Y_DGP = logistic_model +Y_PARAMS_DICT = { + "s": 5, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import 
tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_lss_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_lss_3m_2r_logistic_dgp/dgp.py new file mode 100644 index 0000000..5df43a0 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_lss_3m_2r_logistic_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": 1500, + "sample_col_n": 100 +} +Y_DGP = logistic_lss_model +Y_PARAMS_DICT = { + "m": 3, + "r": 2, + "tau": 0, + "beta": 2, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/classification_sims/splicing_lss_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_lss_3m_2r_logistic_dgp/models.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/classification_sims/splicing_lss_3m_2r_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/copy_models_config.sh b/feature_importance/fi_config/mdi_plus/copy_models_config.sh new file mode 100644 index 0000000..d6beab6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/copy_models_config.sh @@ -0,0 +1,29 @@ +echo 
"Regression DGPs..." +for f in $(find . -type d -name "*_dgp" -o -name "*regression" ! -name "*logistic_dgp" ! -name "*classification" ! -name "*robust_dgp" ! -name "*-"); do + echo "$f" + cp models_regression.py "$f"/models.py +done + +echo "Classification DGPs..." +for f in $(find . -type d -name "*logistic_dgp" -o -name "*classification" ! -name "*-"); do + echo "$f" + cp models_classification.py "$f"/models.py +done + +echo "Robust DGPs..." +for f in $(find . -type d -name "*robust_dgp" ! -name "*-"); do + echo "$f" + cp models_robust.py "$f"/models.py +done + +echo "Bias Sims..." +f=mdi_bias_sims/entropy_sims/linear_dgp +echo "$f" +cp models_bias.py "$f"/models.py +f=mdi_bias_sims/entropy_sims/logistic_dgp +echo "$f" +cp models_bias.py "$f"/models.py +f=mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp +echo "$f" +cp models_bias.py "$f"/models.py + diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/copy_models_config.sh b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/copy_models_config.sh new file mode 100644 index 0000000..5ea6241 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/copy_models_config.sh @@ -0,0 +1,5 @@ +echo "Modeling Choices Classification DGPs..." +for f in $(find . -type d -name "*_dgp"); do + echo "$f" + cp models.py "$f"/models.py +done diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/dgp.py new file mode 100644 index 0000000..ec4e6c5 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_juvenile_cleaned.csv", + "sample_row_n": 1500, + "sample_col_n": None +} +Y_DGP = logistic_hier_model +Y_PARAMS_DICT = { + "m":3, + "r":2, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/models.py new file mode 100644 index 0000000..9dfc2e4 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_hier_poly_3m_2r_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble +from imodels.importance.rf_plus import _fast_r2_score, _neg_log_loss +from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM +from sklearn.metrics import roc_auc_score + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + 
other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score, 'return_stability_scores': True})], + [FIModelConfig('MDI+_logistic_ridge_logloss', tree_mdi_plus, model_type='tree', other_params={'return_stability_scores': True})], + [FIModelConfig('MDI+_logistic_ridge_auroc', tree_mdi_plus, model_type='tree', other_params={'scoring_fns': roc_auc_score, 'return_stability_scores': True})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_logistic_dgp/dgp.py new file mode 100644 index 0000000..28b92fd --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_logistic_dgp/dgp.py @@ -0,0 +1,20 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_juvenile_cleaned.csv", + "sample_row_n": 1500, + "sample_col_n": None +} +Y_DGP = logistic_model +Y_PARAMS_DICT = { + "s": 5, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_logistic_dgp/models.py new file mode 100644 index 0000000..9dfc2e4 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/juvenile_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble +from imodels.importance.rf_plus import _fast_r2_score, _neg_log_loss +from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM +from sklearn.metrics import roc_auc_score + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score, 'return_stability_scores': True})], + [FIModelConfig('MDI+_logistic_ridge_logloss', tree_mdi_plus, model_type='tree', other_params={'return_stability_scores': True})], + [FIModelConfig('MDI+_logistic_ridge_auroc', tree_mdi_plus, model_type='tree', other_params={'scoring_fns': roc_auc_score, 'return_stability_scores': True})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/models.py 
b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/models.py new file mode 100644 index 0000000..9dfc2e4 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble +from imodels.importance.rf_plus import _fast_r2_score, _neg_log_loss +from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM +from sklearn.metrics import roc_auc_score + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score, 'return_stability_scores': True})], + [FIModelConfig('MDI+_logistic_ridge_logloss', tree_mdi_plus, model_type='tree', other_params={'return_stability_scores': True})], + [FIModelConfig('MDI+_logistic_ridge_auroc', tree_mdi_plus, model_type='tree', other_params={'scoring_fns': roc_auc_score, 'return_stability_scores': True})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/dgp.py new file mode 100644 index 0000000..0325f07 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": 1500, + "sample_col_n": 100 +} +Y_DGP = logistic_hier_model +Y_PARAMS_DICT = { + "m":3, + "r":2, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/models.py new file mode 100644 index 0000000..9dfc2e4 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_hier_poly_3m_2r_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble +from imodels.importance.rf_plus import _fast_r2_score, _neg_log_loss +from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM +from sklearn.metrics import roc_auc_score + + 
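+# This config compares MDI+ under different GLM/metric choices (ridge + r2,
+# logistic ridge + log-loss, logistic ridge + AUROC) against MDI and MDI-oob,
+# all computed on the same 100-tree RandomForestClassifier.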
+ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score, 'return_stability_scores': True})], + [FIModelConfig('MDI+_logistic_ridge_logloss', tree_mdi_plus, model_type='tree', other_params={'return_stability_scores': True})], + [FIModelConfig('MDI+_logistic_ridge_auroc', tree_mdi_plus, model_type='tree', other_params={'scoring_fns': roc_auc_score, 'return_stability_scores': True})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_logistic_dgp/dgp.py new file mode 100644 index 0000000..5db817e --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_logistic_dgp/dgp.py @@ -0,0 +1,20 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": 1500, + "sample_col_n": 100 +} +Y_DGP = logistic_model +Y_PARAMS_DICT = { + "s": 5, + "beta": 1, + "frac_label_corruption": None +} + +VARY_PARAM_NAME = ["frac_label_corruption", "sample_row_n"] +VARY_PARAM_VALS = {"frac_label_corruption": {"0": None, "0.05": 0.05, "0.15": 0.15, "0.25": 0.25}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_logistic_dgp/models.py new file mode 100644 index 0000000..9dfc2e4 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/classification_sims/splicing_logistic_dgp/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble +from imodels.importance.rf_plus import _fast_r2_score, _neg_log_loss +from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM +from sklearn.metrics import roc_auc_score + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score, 'return_stability_scores': True})], + [FIModelConfig('MDI+_logistic_ridge_logloss', tree_mdi_plus, model_type='tree', other_params={'return_stability_scores': True})], + [FIModelConfig('MDI+_logistic_ridge_auroc', tree_mdi_plus, model_type='tree', other_params={'scoring_fns': roc_auc_score, 'return_stability_scores': True})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git 
a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..c57710e --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..e454519 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,22 @@ +import copy +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig +from imodels.importance.rf_plus import RandomForestPlusRegressor +from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM + +rf_model = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=27) +ridge_model = RidgeRegressorPPM() +lasso_model = LassoRegressorPPM() + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})], + [ModelConfig('RF-ridge', RandomForestPlusRegressor, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(ridge_model)})], + [ModelConfig('RF-lasso', RandomForestPlusRegressor, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(lasso_model)})] +] + +FI_ESTIMATORS = [] diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_linear_dgp/dgp.py new file mode 100644 index 0000000..193daff --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git 
a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_linear_dgp/models.py new file mode 100644 index 0000000..e454519 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/ccle_rnaseq_linear_dgp/models.py @@ -0,0 +1,22 @@ +import copy +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig +from imodels.importance.rf_plus import RandomForestPlusRegressor +from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM + +rf_model = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=27) +ridge_model = RidgeRegressorPPM() +lasso_model = LassoRegressorPPM() + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})], + [ModelConfig('RF-ridge', RandomForestPlusRegressor, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(ridge_model)})], + [ModelConfig('RF-lasso', RandomForestPlusRegressor, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(lasso_model)})] +] + +FI_ESTIMATORS = [] diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/copy_models_config.sh b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/copy_models_config.sh new file mode 100644 index 0000000..9c837ca --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/copy_models_config.sh @@ -0,0 +1,5 @@ +echo "Modeling Choices Regression DGPs..." +for f in $(find . 
-type d -name "*_dgp"); do + echo "$f" + cp models.py "$f"/models.py +done diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..f328986 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": int(100 * 1.5), "250": int(250 * 1.5), "500": int(500 * 1.5), "1000": int(1000 * 1.5), "1500": int(1500 * 1.5)}} diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..e454519 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,22 @@ +import copy +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig +from imodels.importance.rf_plus import RandomForestPlusRegressor +from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM + +rf_model = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=27) +ridge_model = RidgeRegressorPPM() +lasso_model = LassoRegressorPPM() + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})], + [ModelConfig('RF-ridge', RandomForestPlusRegressor, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(ridge_model)})], + [ModelConfig('RF-lasso', RandomForestPlusRegressor, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(lasso_model)})] +] + +FI_ESTIMATORS = [] diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_linear_dgp/dgp.py new file mode 100644 index 0000000..0bed41f --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 
0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": int(100 * 1.5), "250": int(250 * 1.5), "500": int(500 * 1.5), "1000": int(1000 * 1.5), "1500": int(1500 * 1.5)}} diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_linear_dgp/models.py new file mode 100644 index 0000000..e454519 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/enhancer_linear_dgp/models.py @@ -0,0 +1,22 @@ +import copy +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig +from imodels.importance.rf_plus import RandomForestPlusRegressor +from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM + +rf_model = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=27) +ridge_model = RidgeRegressorPPM() +lasso_model = LassoRegressorPPM() + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})], + [ModelConfig('RF-ridge', RandomForestPlusRegressor, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(ridge_model)})], + [ModelConfig('RF-lasso', RandomForestPlusRegressor, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(lasso_model)})] +] + +FI_ESTIMATORS = [] diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/models.py new file mode 100644 index 0000000..e454519 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_prediction_sims/models.py @@ -0,0 +1,22 @@ +import copy +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig +from imodels.importance.rf_plus import RandomForestPlusRegressor +from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM + +rf_model = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=27) +ridge_model = RidgeRegressorPPM() +lasso_model = LassoRegressorPPM() + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})], + [ModelConfig('RF-ridge', RandomForestPlusRegressor, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(ridge_model)})], + [ModelConfig('RF-lasso', RandomForestPlusRegressor, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(lasso_model)})] +] + +FI_ESTIMATORS = [] diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..c57710e --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util 
import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..1cb8ff1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,23 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig, neg_mean_absolute_error +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble +from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM +from sklearn.metrics import mean_absolute_error, r2_score + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'return_stability_scores': True})], + [FIModelConfig('MDI+_lasso_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'return_stability_scores': True})], + [FIModelConfig('MDI+_ridge_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})], + [FIModelConfig('MDI+_lasso_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})], + [FIModelConfig('MDI+_ensemble', tree_mdi_plus_ensemble, model_type='tree', ascending=False, + other_params={"ridge": {"prediction_model": RidgeRegressorPPM(gcv_mode='eigen')}, + "lasso": {"prediction_model": LassoRegressorPPM()}, + "scoring_fns": {"r2": r2_score, "mae": neg_mean_absolute_error}})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_linear_dgp/dgp.py new file mode 100644 index 0000000..193daff --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] 
+VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_linear_dgp/models.py new file mode 100644 index 0000000..1cb8ff1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/ccle_rnaseq_linear_dgp/models.py @@ -0,0 +1,23 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig, neg_mean_absolute_error +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble +from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM +from sklearn.metrics import mean_absolute_error, r2_score + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'return_stability_scores': True})], + [FIModelConfig('MDI+_lasso_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'return_stability_scores': True})], + [FIModelConfig('MDI+_ridge_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})], + [FIModelConfig('MDI+_lasso_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})], + [FIModelConfig('MDI+_ensemble', tree_mdi_plus_ensemble, model_type='tree', ascending=False, + other_params={"ridge": {"prediction_model": RidgeRegressorPPM(gcv_mode='eigen')}, + "lasso": {"prediction_model": LassoRegressorPPM()}, + "scoring_fns": {"r2": r2_score, "mae": neg_mean_absolute_error}})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/copy_models_config.sh b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/copy_models_config.sh new file mode 100644 index 0000000..9c837ca --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/copy_models_config.sh @@ -0,0 +1,5 @@ +echo "Modeling Choices Regression DGPs..." +for f in $(find . 
-type d -name "*_dgp"); do + echo "$f" + cp models.py "$f"/models.py +done diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..5e3b23d --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..1cb8ff1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,23 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig, neg_mean_absolute_error +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble +from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM +from sklearn.metrics import mean_absolute_error, r2_score + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'return_stability_scores': True})], + [FIModelConfig('MDI+_lasso_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'return_stability_scores': True})], + [FIModelConfig('MDI+_ridge_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})], + [FIModelConfig('MDI+_lasso_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})], + [FIModelConfig('MDI+_ensemble', tree_mdi_plus_ensemble, model_type='tree', ascending=False, + other_params={"ridge": {"prediction_model": RidgeRegressorPPM(gcv_mode='eigen')}, + "lasso": {"prediction_model": LassoRegressorPPM()}, + "scoring_fns": {"r2": r2_score, "mae": neg_mean_absolute_error}})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_linear_dgp/dgp.py 
b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_linear_dgp/dgp.py new file mode 100644 index 0000000..d09d33a --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_linear_dgp/models.py new file mode 100644 index 0000000..1cb8ff1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/enhancer_linear_dgp/models.py @@ -0,0 +1,23 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig, neg_mean_absolute_error +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble +from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM +from sklearn.metrics import mean_absolute_error, r2_score + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'return_stability_scores': True})], + [FIModelConfig('MDI+_lasso_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'return_stability_scores': True})], + [FIModelConfig('MDI+_ridge_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})], + [FIModelConfig('MDI+_lasso_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})], + [FIModelConfig('MDI+_ensemble', tree_mdi_plus_ensemble, model_type='tree', ascending=False, + other_params={"ridge": {"prediction_model": RidgeRegressorPPM(gcv_mode='eigen')}, + "lasso": {"prediction_model": LassoRegressorPPM()}, + "scoring_fns": {"r2": r2_score, "mae": neg_mean_absolute_error}})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/models.py b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/models.py new file mode 100644 index 0000000..1cb8ff1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/glm_metric_choices_sims/regression_sims/models.py @@ -0,0 +1,23 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, 
FIModelConfig, neg_mean_absolute_error +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap, tree_mdi_plus_ensemble +from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM +from sklearn.metrics import mean_absolute_error, r2_score + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_ridge_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'return_stability_scores': True})], + [FIModelConfig('MDI+_lasso_r2', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'return_stability_scores': True})], + [FIModelConfig('MDI+_ridge_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen'), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})], + [FIModelConfig('MDI+_lasso_neg_mae', tree_mdi_plus, model_type='tree', other_params={'prediction_model': LassoRegressorPPM(), 'scoring_fns': neg_mean_absolute_error, 'return_stability_scores': True})], + [FIModelConfig('MDI+_ensemble', tree_mdi_plus_ensemble, model_type='tree', ascending=False, + other_params={"ridge": {"prediction_model": RidgeRegressorPPM(gcv_mode='eigen')}, + "lasso": {"prediction_model": LassoRegressorPPM()}, + "scoring_fns": {"r2": r2_score, "mae": neg_mean_absolute_error}})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/__init__.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/dgp.py new file mode 100644 index 0000000..9b2fb69 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/dgp.py @@ -0,0 +1,29 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +n = 250 +d = 100 +n_correlated = 50 + +X_DGP = sample_block_cor_X +X_PARAMS_DICT = { + "n": n, + "d": d, + "rho": [0.8] + [0 for i in range(int(d / n_correlated - 1))], + "n_blocks": int(d / n_correlated) +} +Y_DGP = partial_linear_lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s": 1, + "m": 3, + "r": 2, + "tau": 0 +} + +VARY_PARAM_NAME = ["heritability", "rho"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.4": 0.4}, + "rho": {"0.5": [0.5] + [0 for i in range(int(d / n_correlated - 1))], "0.6": [0.6] + [0 for i in range(int(d / n_correlated - 1))], "0.7": [0.7] + [0 for i in range(int(d / n_correlated - 1))], "0.8": [0.8] + [0 for i in range(int(d / n_correlated - 1))], "0.9": [0.9] + [0 for i in range(int(d / n_correlated - 1))], "0.99": [0.99] + [0 for i in range(int(d / n_correlated - 1))]}} diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/models.py 
b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/models.py new file mode 100644 index 0000000..cbfbc5b --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/correlation_sims/normal_block_cor_partial_linear_lss_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.ppms import RidgeRegressorPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_inbag', tree_mdi_plus, model_type='tree', other_params={"sample_split": "inbag", "prediction_model": RidgeRegressorPPM(loo=False)})], + [FIModelConfig('MDI+_inbag_noraw', tree_mdi_plus, model_type='tree', other_params={"include_raw": False, "sample_split": "inbag", "prediction_model": RidgeRegressorPPM(loo=False)})], + [FIModelConfig('MDI+_noraw', tree_mdi_plus, model_type='tree', other_params={"include_raw": False})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI_with_splits', tree_mdi, model_type='tree', other_params={"include_num_splits": True})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/__init__.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/dgp.py new file mode 100644 index 0000000..9cadad5 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/dgp.py @@ -0,0 +1,20 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + + +X_DGP = entropy_X +X_PARAMS_DICT = { + "n": 100 +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s": 1 +} + +VARY_PARAM_NAME = ["heritability", "n"] +VARY_PARAM_VALS = {"heritability": {"0.05": 0.05, "0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.6": 0.6, "0.8": 0.8}, + "n": {"50": 50, "100": 100, "250": 250, "500": 500, "1000": 1000}} diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/models.py new file mode 100644 index 0000000..cbfbc5b --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/linear_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.ppms import RidgeRegressorPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 
'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_inbag', tree_mdi_plus, model_type='tree', other_params={"sample_split": "inbag", "prediction_model": RidgeRegressorPPM(loo=False)})], + [FIModelConfig('MDI+_inbag_noraw', tree_mdi_plus, model_type='tree', other_params={"include_raw": False, "sample_split": "inbag", "prediction_model": RidgeRegressorPPM(loo=False)})], + [FIModelConfig('MDI+_noraw', tree_mdi_plus, model_type='tree', other_params={"include_raw": False})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI_with_splits', tree_mdi, model_type='tree', other_params={"include_num_splits": True})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/__init__.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/dgp.py new file mode 100644 index 0000000..380d0ef --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/dgp.py @@ -0,0 +1,17 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + + +X_DGP = entropy_X +X_PARAMS_DICT = { + "n": 100 +} +Y_DGP = entropy_y +Y_PARAMS_DICT = { + "c": 3 +} + +VARY_PARAM_NAME = ["c", "n"] +VARY_PARAM_VALS = {"c": {"3": 3}, + "n": {"50": 50, "100": 100, "250": 250, "500": 500, "1000": 1000}} diff --git a/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/models.py b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/models.py new file mode 100644 index 0000000..f6c0b7d --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/mdi_bias_sims/entropy_sims/logistic_dgp/models.py @@ -0,0 +1,24 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(gcv_mode='eigen'), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_ridge_inbag', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(loo=False, gcv_mode='eigen'), 'scoring_fns': _fast_r2_score, "sample_split": "inbag"})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_logistic_logloss_inbag', tree_mdi_plus, model_type='tree', other_params={"sample_split": "inbag", "prediction_model": LogisticClassifierPPM(loo=False)})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI_with_splits', tree_mdi, model_type='tree', 
other_params={"include_num_splits": True})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] + diff --git a/feature_importance/fi_config/mdi_plus/models_bias.py b/feature_importance/fi_config/mdi_plus/models_bias.py new file mode 100644 index 0000000..cbfbc5b --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/models_bias.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.ppms import RidgeRegressorPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_inbag', tree_mdi_plus, model_type='tree', other_params={"sample_split": "inbag", "prediction_model": RidgeRegressorPPM(loo=False)})], + [FIModelConfig('MDI+_inbag_noraw', tree_mdi_plus, model_type='tree', other_params={"include_raw": False, "sample_split": "inbag", "prediction_model": RidgeRegressorPPM(loo=False)})], + [FIModelConfig('MDI+_noraw', tree_mdi_plus, model_type='tree', other_params={"include_raw": False})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI_with_splits', tree_mdi, model_type='tree', other_params={"include_num_splits": True})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/models_classification.py b/feature_importance/fi_config/mdi_plus/models_classification.py new file mode 100644 index 0000000..45c16b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/models_classification.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': _fast_r2_score})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/models_regression.py b/feature_importance/fi_config/mdi_plus/models_regression.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/models_regression.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from 
feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/models_robust.py b/feature_importance/fi_config/mdi_plus/models_robust.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/models_robust.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..c57710e --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..6071e78 --- /dev/null +++ 
b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,18 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.ppms import RidgeRegressorPPM + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})], + [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_linear_dgp/dgp.py new file mode 100644 index 0000000..193daff --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_linear_dgp/models.py new file mode 100644 index 0000000..6071e78 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/ccle_rnaseq_linear_dgp/models.py @@ -0,0 +1,18 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.ppms import RidgeRegressorPPM + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, 
model_type='tree')], + [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})], + [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/copy_models_config.sh b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/copy_models_config.sh new file mode 100644 index 0000000..9c837ca --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/copy_models_config.sh @@ -0,0 +1,5 @@ +echo "Modeling Choices Regression DGPs..." +for f in $(find . -type d -name "*_dgp"); do + echo "$f" + cp models.py "$f"/models.py +done diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..5e3b23d --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..6071e78 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,18 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.ppms import RidgeRegressorPPM + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_loo', 
tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})], + [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_linear_dgp/dgp.py new file mode 100644 index 0000000..d09d33a --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_linear_dgp/models.py new file mode 100644 index 0000000..6071e78 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/enhancer_linear_dgp/models.py @@ -0,0 +1,18 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.ppms import RidgeRegressorPPM + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})], + [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': 
RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/models.py new file mode 100644 index 0000000..6071e78 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples1/models.py @@ -0,0 +1,18 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.ppms import RidgeRegressorPPM + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})], + [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..c57710e --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..dffdf47 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,18 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import 
ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.ppms import RidgeRegressorPPM + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})], + [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_linear_dgp/dgp.py new file mode 100644 index 0000000..193daff --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_linear_dgp/models.py new file mode 100644 index 0000000..dffdf47 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/ccle_rnaseq_linear_dgp/models.py @@ -0,0 +1,18 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.ppms import RidgeRegressorPPM + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + 
[FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})], + [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/copy_models_config.sh b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/copy_models_config.sh new file mode 100644 index 0000000..9c837ca --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/copy_models_config.sh @@ -0,0 +1,5 @@ +echo "Modeling Choices Regression DGPs..." +for f in $(find . -type d -name "*_dgp"); do + echo "$f" + cp models.py "$f"/models.py +done diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..5e3b23d --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..dffdf47 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,18 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.ppms import RidgeRegressorPPM + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_raw', tree_mdi_plus, 
model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})], + [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_linear_dgp/dgp.py new file mode 100644 index 0000000..d09d33a --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_linear_dgp/models.py new file mode 100644 index 0000000..dffdf47 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/enhancer_linear_dgp/models.py @@ -0,0 +1,18 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.ppms import RidgeRegressorPPM + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})], + [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git 
a/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/models.py new file mode 100644 index 0000000..dffdf47 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/modeling_choices_min_samples5/models.py @@ -0,0 +1,18 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.ppms import RidgeRegressorPPM + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI+_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': False, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_raw', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'inbag', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=False, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI+_ridge_raw_loo', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeRegressorPPM(gcv_mode='eigen')})], + [FIModelConfig('MDI+_ols_raw_loo', tree_mdi_plus, model_type='tree', other_params={'sample_split': 'loo', 'include_raw': True, 'prediction_model': RidgeRegressorPPM(loo=True, alpha_grid=1e-6, gcv_mode='eigen')})], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..064576e --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_col_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_col_n": {"10": 10, "25": 25, "50": 50, "100": 100, "250": 250, "500": 500, "1000": 1000, "2000": 2000}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, 
model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_dgp/dgp.py new file mode 100644 index 0000000..cc56c30 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "sample_col_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_col_n": {"10": 10, "25": 25, "50": 50, "100": 100, "250": 250, "500": 500, "1000": 1000, "2000": 2000}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_lss_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..a54aec8 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_lss_3m_2r_dgp/dgp.py @@ -0,0 +1,24 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = partial_linear_lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "tau": 0, + "m": 3, + "r": 2, + "s": 1, +} + +VARY_PARAM_NAME = ["heritability", "sample_col_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + 
"sample_col_n": {"10": 10, "25": 25, "50": 50, "100": 100, "250": 250, "500": 500, "1000": 1000, "2000": 2000}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_lss_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_linear_lss_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_lss_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..70768a7 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_lss_3m_2r_dgp/dgp.py @@ -0,0 +1,23 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "tau": 0, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_col_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_col_n": {"10": 10, "25": 25, "50": 50, "100": 100, "250": 250, "500": 500, "1000": 1000, "2000": 2000}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_lss_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_p/ccle_rnaseq_lss_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git 
a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..abbd688 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_juvenile_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": None +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "m"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "m": {"2": 2, "3": 3, "5": 5, "7": 7, "10": 10}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_dgp/dgp.py new file mode 100644 index 0000000..99ae392 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_juvenile_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": None +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "s"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "s": {"1": 1, "5": 5, "10": 10, "25": 25, "50": 50}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor 
+from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_lss_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..c23b7c2 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_lss_3m_2r_dgp/dgp.py @@ -0,0 +1,24 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_juvenile_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": None +} +Y_DGP = partial_linear_lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "tau": 0, + "m": 3, + "r": 2, + "s": 1, +} + +VARY_PARAM_NAME = ["heritability", "m"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "m": {"2": 2, "3": 3, "5": 5, "7": 7, "10": 10}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_lss_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_linear_lss_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/__init__.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/__init__.py new file mode 100644 index 0000000..8d1c8b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/__init__.py @@ -0,0 +1 @@ + diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/dgp.py 
new file mode 100644 index 0000000..429916e --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/dgp.py @@ -0,0 +1,23 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_juvenile_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": None +} +Y_DGP = lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "tau": 0, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "m"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "m": {"2": 2, "3": 3, "5": 5, "7": 7, "10": 10}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/juvenile_lss_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..69e810c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": 100 +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "m"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "m": {"2": 2, "3": 3, "5": 5, "7": 7, "10": 10}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + 
+ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_dgp/dgp.py new file mode 100644 index 0000000..e049de9 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": 100 +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "s"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "s": {"1": 1, "5": 5, "10": 10, "25": 25, "50": 50}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_lss_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..ab525ee --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_lss_3m_2r_dgp/dgp.py @@ -0,0 +1,24 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": 100 +} +Y_DGP = partial_linear_lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "tau": 0, + "m": 3, + "r": 2, + "s": 1, +} + +VARY_PARAM_NAME = ["heritability", "m"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 
0.8}, + "m": {"2": 2, "3": 3, "5": 5, "7": 7, "10": 10}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_lss_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_linear_lss_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_lss_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..9611cb6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_lss_3m_2r_dgp/dgp.py @@ -0,0 +1,23 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": 100 +} +Y_DGP = lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "tau": 0, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "m"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "m": {"2": 2, "3": 3, "5": 5, "7": 7, "10": 10}} diff --git a/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_lss_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/other_regression_sims/varying_sparsity/splicing_lss_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/__init__.py 
b/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/dgp.py b/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/dgp.py new file mode 100644 index 0000000..b5860cd --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/dgp.py @@ -0,0 +1,2 @@ +X_PATH = "data/X_ccle_rnaseq_cleaned_filtered5000.csv" +Y_PATH = "data/y_ccle_rnaseq.csv" diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/models.py b/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/models.py new file mode 100644 index 0000000..dfe1a3e --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/prediction_sims/ccle_rnaseq_regression-/models.py @@ -0,0 +1,24 @@ +import copy +import numpy as np +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig +from imodels.importance.rf_plus import RandomForestPlusRegressor +from imodels.importance.ppms import RidgeRegressorPPM, LassoRegressorPPM + + +rf_model = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=42) +ridge_model = RidgeRegressorPPM() +lasso_model = LassoRegressorPPM() + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})], + [ModelConfig('RF-ridge', RandomForestPlusRegressor, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(ridge_model)})], + [ModelConfig('RF-lasso', RandomForestPlusRegressor, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(lasso_model)})], +] + +FI_ESTIMATORS = [] diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/__init__.py b/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/dgp.py b/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/dgp.py new file mode 100644 index 0000000..89ec3e0 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/dgp.py @@ -0,0 +1,2 @@ +X_PATH = "data/X_enhancer_all.csv" +Y_PATH = "data/y_enhancer.csv" diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/models.py b/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/models.py new file mode 100644 index 0000000..8ba5961 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/prediction_sims/enhancer_classification-/models.py @@ -0,0 +1,26 @@ +import copy +import numpy as np +from sklearn.ensemble import RandomForestClassifier +from sklearn.linear_model import RidgeClassifierCV, LogisticRegressionCV +from sklearn.utils.extmath import softmax +from feature_importance.util import ModelConfig +from imodels.importance.rf_plus import RandomForestPlusClassifier +from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM + + +rf_model = RandomForestClassifier(n_estimators=100, min_samples_leaf=1, max_features='sqrt', random_state=42) +ridge_model = RidgeClassifierPPM() +logistic_model = 
LogisticClassifierPPM() + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 42})], + [ModelConfig('RF-ridge', RandomForestPlusClassifier, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(ridge_model)})], + [ModelConfig('RF-logistic', RandomForestPlusClassifier, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(logistic_model)})], +] + +FI_ESTIMATORS = [] diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/__init__.py b/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/dgp.py b/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/dgp.py new file mode 100644 index 0000000..79ab2f0 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/dgp.py @@ -0,0 +1,2 @@ +X_PATH = "data/X_juvenile_cleaned.csv" +Y_PATH = "data/y_juvenile.csv" diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/models.py b/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/models.py new file mode 100644 index 0000000..8ba5961 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/prediction_sims/juvenile_classification-/models.py @@ -0,0 +1,26 @@ +import copy +import numpy as np +from sklearn.ensemble import RandomForestClassifier +from sklearn.linear_model import RidgeClassifierCV, LogisticRegressionCV +from sklearn.utils.extmath import softmax +from feature_importance.util import ModelConfig +from imodels.importance.rf_plus import RandomForestPlusClassifier +from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM + + +rf_model = RandomForestClassifier(n_estimators=100, min_samples_leaf=1, max_features='sqrt', random_state=42) +ridge_model = RidgeClassifierPPM() +logistic_model = LogisticClassifierPPM() + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 42})], + [ModelConfig('RF-ridge', RandomForestPlusClassifier, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(ridge_model)})], + [ModelConfig('RF-logistic', RandomForestPlusClassifier, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(logistic_model)})], +] + +FI_ESTIMATORS = [] diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/__init__.py b/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/dgp.py b/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/dgp.py new file mode 100644 index 0000000..079b2da --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/dgp.py @@ -0,0 +1,2 @@ +X_PATH = "data/X_splicing_cleaned.csv" +Y_PATH = "data/y_splicing.csv" diff --git 
a/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/models.py b/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/models.py new file mode 100644 index 0000000..8ba5961 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/prediction_sims/splicing_classification-/models.py @@ -0,0 +1,26 @@ +import copy +import numpy as np +from sklearn.ensemble import RandomForestClassifier +from sklearn.linear_model import RidgeClassifierCV, LogisticRegressionCV +from sklearn.utils.extmath import softmax +from feature_importance.util import ModelConfig +from imodels.importance.rf_plus import RandomForestPlusClassifier +from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM + + +rf_model = RandomForestClassifier(n_estimators=100, min_samples_leaf=1, max_features='sqrt', random_state=42) +ridge_model = RidgeClassifierPPM() +logistic_model = LogisticClassifierPPM() + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 42})], + [ModelConfig('RF-ridge', RandomForestPlusClassifier, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(ridge_model)})], + [ModelConfig('RF-logistic', RandomForestPlusClassifier, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(logistic_model)})], +] + +FI_ESTIMATORS = [] diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/__init__.py b/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/dgp.py b/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/dgp.py new file mode 100644 index 0000000..490a687 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/dgp.py @@ -0,0 +1,2 @@ +X_PATH = "data/X_tcga_cleaned.csv" +Y_PATH = "data/Y_tcga.csv" diff --git a/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/models.py b/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/models.py new file mode 100644 index 0000000..8ba5961 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/prediction_sims/tcga_brca_classification-/models.py @@ -0,0 +1,26 @@ +import copy +import numpy as np +from sklearn.ensemble import RandomForestClassifier +from sklearn.linear_model import RidgeClassifierCV, LogisticRegressionCV +from sklearn.utils.extmath import softmax +from feature_importance.util import ModelConfig +from imodels.importance.rf_plus import RandomForestPlusClassifier +from imodels.importance.ppms import RidgeClassifierPPM, LogisticClassifierPPM + + +rf_model = RandomForestClassifier(n_estimators=100, min_samples_leaf=1, max_features='sqrt', random_state=42) +ridge_model = RidgeClassifierPPM() +logistic_model = LogisticClassifierPPM() + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 42})], + [ModelConfig('RF-ridge', RandomForestPlusClassifier, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(ridge_model)})], + [ModelConfig('RF-logistic', 
RandomForestPlusClassifier, model_type='tree', + other_params={'rf_model': copy.deepcopy(rf_model), + 'prediction_model': copy.deepcopy(logistic_model)})], +] + +FI_ESTIMATORS = [] diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/__init__.py b/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/dgp.py b/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/dgp.py new file mode 100644 index 0000000..b5860cd --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/dgp.py @@ -0,0 +1,2 @@ +X_PATH = "data/X_ccle_rnaseq_cleaned_filtered5000.csv" +Y_PATH = "data/y_ccle_rnaseq.csv" diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/models.py b/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/models.py new file mode 100644 index 0000000..4b1dbf8 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+', tree_mdi_plus, model_type='tree', splitting_strategy="train-test-prediction")], + [FIModelConfig('MDI', tree_mdi, model_type='tree', splitting_strategy="train-test-prediction")], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree', splitting_strategy="train-test-prediction")], + [FIModelConfig('MDA', tree_mda, model_type='tree', splitting_strategy="train-test-prediction")], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree', splitting_strategy="train-test-prediction")] +] diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/__init__.py b/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/dgp.py b/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/dgp.py new file mode 100644 index 0000000..490a687 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/dgp.py @@ -0,0 +1,2 @@ +X_PATH = "data/X_tcga_cleaned.csv" +Y_PATH = "data/Y_tcga.csv" diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/models.py b/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/models.py new file mode 100644 index 0000000..061791f --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/real_data_case_study/tcga_brca_classification-/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from 
imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM +from functools import partial + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 42})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', splitting_strategy="train-test-prediction", other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': partial(_fast_r2_score, multiclass=True)})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree', splitting_strategy="train-test-prediction")], + [FIModelConfig('MDI', tree_mdi, model_type='tree', splitting_strategy="train-test-prediction")], + [FIModelConfig('MDA', tree_mda, model_type='tree', splitting_strategy="train-test-prediction")], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree', splitting_strategy="train-test-prediction")] +] diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/__init__.py b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/dgp.py b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/dgp.py new file mode 100644 index 0000000..b5860cd --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/dgp.py @@ -0,0 +1,2 @@ +X_PATH = "data/X_ccle_rnaseq_cleaned_filtered5000.csv" +Y_PATH = "data/y_ccle_rnaseq.csv" diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/models.py b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/models.py new file mode 100644 index 0000000..4cde78f --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/ccle_rnaseq_regression-/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 42})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/__init__.py b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/dgp.py b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/dgp.py new file mode 100644 index 0000000..490a687 --- /dev/null +++ 
b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/dgp.py @@ -0,0 +1,2 @@ +X_PATH = "data/X_tcga_cleaned.csv" +Y_PATH = "data/Y_tcga.csv" diff --git a/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/models.py b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/models.py new file mode 100644 index 0000000..755fc01 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/real_data_case_study_no_data_split/tcga_brca_classification-/models.py @@ -0,0 +1,20 @@ +from sklearn.ensemble import RandomForestClassifier +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap +from imodels.importance.rf_plus import _fast_r2_score +from imodels.importance.ppms import RidgeClassifierPPM +from functools import partial + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestClassifier, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'random_state': 42})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree', other_params={'prediction_model': RidgeClassifierPPM(), 'scoring_fns': partial(_fast_r2_score, multiclass=True)})], + [FIModelConfig('MDI+_logistic_logloss', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..c57710e --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, 
model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_dgp/dgp.py new file mode 100644 index 0000000..193daff --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_lss_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..166f958 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_lss_3m_2r_dgp/dgp.py @@ -0,0 +1,24 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = partial_linear_lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "tau": 0, + "m": 3, + "r": 2, + "s": 1, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_lss_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_linear_lss_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from 
feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_lss_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..a1b17f6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_lss_3m_2r_dgp/dgp.py @@ -0,0 +1,23 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "tau": 0, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_lss_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/ccle_rnaseq_lss_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..5e3b23d --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, 
"0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_dgp/dgp.py new file mode 100644 index 0000000..d09d33a --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_lss_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..919157d --- /dev/null +++ 
b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_lss_3m_2r_dgp/dgp.py @@ -0,0 +1,24 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = partial_linear_lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "tau": 0, + "m": 3, + "r": 2, + "s": 1, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_lss_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_linear_lss_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_lss_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..9c30166 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_lss_3m_2r_dgp/dgp.py @@ -0,0 +1,23 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "tau": 0, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_lss_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/enhancer_lss_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] 
+ +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..34a80c8 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_juvenile_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": None +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_dgp/dgp.py new file mode 100644 index 0000000..3a140eb --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_juvenile_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": None +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ 
b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_lss_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..d8744e6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_lss_3m_2r_dgp/dgp.py @@ -0,0 +1,24 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_juvenile_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": None +} +Y_DGP = partial_linear_lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "tau": 0, + "m": 3, + "r": 2, + "s": 1, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_lss_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_linear_lss_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/__init__.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/__init__.py new file mode 100644 index 0000000..8d1c8b6 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/__init__.py @@ -0,0 +1 @@ + diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..34d4f3f --- 
/dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/dgp.py @@ -0,0 +1,23 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_juvenile_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": None +} +Y_DGP = lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "tau": 0, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/juvenile_lss_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_hier_poly_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_hier_poly_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..6f2ef48 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_hier_poly_3m_2r_dgp/dgp.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": 100 +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_hier_poly_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_hier_poly_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_hier_poly_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + 
+FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_dgp/dgp.py new file mode 100644 index 0000000..339a98d --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_dgp/dgp.py @@ -0,0 +1,21 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": 100 +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s":5, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_lss_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..48ca787 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_lss_3m_2r_dgp/dgp.py @@ -0,0 +1,24 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": 100 +} +Y_DGP = partial_linear_lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "tau": 0, + "m": 3, + "r": 2, + "s": 1, +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_lss_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- 
/dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_linear_lss_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_lss_3m_2r_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_lss_3m_2r_dgp/dgp.py new file mode 100644 index 0000000..7204b75 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_lss_3m_2r_dgp/dgp.py @@ -0,0 +1,23 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": 1000, + "sample_col_n": 100 +} +Y_DGP = lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "tau": 0, + "m": 3, + "r": 2 +} + +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000, "1500": 1500}} diff --git a/feature_importance/fi_config/mdi_plus/regression_sims/splicing_lss_3m_2r_dgp/models.py b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_lss_3m_2r_dgp/models.py new file mode 100644 index 0000000..44f116c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/regression_sims/splicing_lss_3m_2r_dgp/models.py @@ -0,0 +1,17 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')], +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_10MS_robust_dgp/dgp.py new file mode 100644 index 0000000..74a7fb8 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_10MS_robust_dgp/dgp.py @@ -0,0 +1,25 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = 
hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 10 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_10MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_10MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_25MS_robust_dgp/dgp.py new file mode 100644 index 0000000..817e499 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_25MS_robust_dgp/dgp.py @@ -0,0 +1,25 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 25 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_25MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_hier_poly_3m_2r_25MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util 
import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_10MS_robust_dgp/dgp.py new file mode 100644 index 0000000..0a6db88 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_10MS_robust_dgp/dgp.py @@ -0,0 +1,24 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s": 5, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 10 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_10MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_10MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + 
[FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_25MS_robust_dgp/dgp.py new file mode 100644 index 0000000..e6b0b6a --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_25MS_robust_dgp/dgp.py @@ -0,0 +1,24 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s": 5, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 25 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_25MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_25MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_10MS_robust_dgp/dgp.py new file mode 100644 index 0000000..f7cc517 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_10MS_robust_dgp/dgp.py @@ -0,0 +1,27 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = partial_linear_lss_model +Y_PARAMS_DICT = { + "beta": 1, + "tau": 0, + "sigma": None, + "heritability": 0.4, + "s": 1, + "m": 3, + 
"r": 2, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 10 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_10MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_10MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_25MS_robust_dgp/dgp.py new file mode 100644 index 0000000..1a55ba0 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_25MS_robust_dgp/dgp.py @@ -0,0 +1,27 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = partial_linear_lss_model +Y_PARAMS_DICT = { + "beta": 1, + "tau": 0, + "sigma": None, + "heritability": 0.4, + "s": 1, + "m": 3, + "r": 2, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 25 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_25MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_linear_lss_3m_2r_25MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from 
feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_10MS_robust_dgp/dgp.py new file mode 100644 index 0000000..c0634ea --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_10MS_robust_dgp/dgp.py @@ -0,0 +1,26 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2, + "tau": 0, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 10 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_10MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_10MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + 
[FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_25MS_robust_dgp/dgp.py new file mode 100644 index 0000000..b36e69c --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_25MS_robust_dgp/dgp.py @@ -0,0 +1,26 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_ccle_rnaseq_cleaned.csv", + "sample_row_n": None, + "sample_col_n": 1000 +} +Y_DGP = lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2, + "tau": 0, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 25 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "472": 472}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_25MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/ccle_rnaseq_lss_3m_2r_25MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_10MS_robust_dgp/dgp.py new file mode 100644 index 0000000..d64bf10 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_10MS_robust_dgp/dgp.py @@ -0,0 +1,25 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2, 
+ "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 10 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "500": 500,"1000":1000,"1500":1500}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_10MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_10MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_25MS_robust_dgp/dgp.py new file mode 100644 index 0000000..8f87a35 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_25MS_robust_dgp/dgp.py @@ -0,0 +1,25 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = hierarchical_poly +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 25 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "500": 500,"1000":1000,"1500":1500}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_25MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_hier_poly_3m_2r_25MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import 
tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_10MS_robust_dgp/dgp.py new file mode 100644 index 0000000..864f016 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_10MS_robust_dgp/dgp.py @@ -0,0 +1,24 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s": 5, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 10 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "500": 500,"1000":1000,"1500":1500}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_10MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_10MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', 
tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_25MS_robust_dgp/dgp.py new file mode 100644 index 0000000..74e744b --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_25MS_robust_dgp/dgp.py @@ -0,0 +1,24 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s": 5, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 25 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "500": 500,"1000":1000,"1500":1500}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_25MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_25MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_10MS_robust_dgp/dgp.py new file mode 100644 index 0000000..a17116e --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_10MS_robust_dgp/dgp.py @@ -0,0 +1,27 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = partial_linear_lss_model +Y_PARAMS_DICT = { + "beta": 1, + "tau": 0, + "sigma": None, + "heritability": 0.4, + "s": 1, + "m": 3, + "r": 2, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 10 +} + 
+VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "500": 500,"1000":1000,"1500":1500}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_10MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_10MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_25MS_robust_dgp/dgp.py new file mode 100644 index 0000000..8659075 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_25MS_robust_dgp/dgp.py @@ -0,0 +1,27 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = partial_linear_lss_model +Y_PARAMS_DICT = { + "beta": 1, + "tau": 0, + "sigma": None, + "heritability": 0.4, + "s": 1, + "m": 3, + "r": 2, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 25 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "500": 500,"1000":1000,"1500":1500}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_25MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_linear_lss_3m_2r_25MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, 
tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_10MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_10MS_robust_dgp/dgp.py new file mode 100644 index 0000000..25ab03f --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_10MS_robust_dgp/dgp.py @@ -0,0 +1,26 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2, + "tau": 0, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 10 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "500": 500,"1000":1000,"1500":1500}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_10MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_10MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_10MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, 
model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_25MS_robust_dgp/dgp.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_25MS_robust_dgp/dgp.py new file mode 100644 index 0000000..b3599de --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_25MS_robust_dgp/dgp.py @@ -0,0 +1,26 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_enhancer_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = lss_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "m": 3, + "r": 2, + "tau": 0, + "corrupt_how": "leverage_normal", + "corrupt_size": 0.1, + "corrupt_mean": 25 +} + +VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"] +VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025}, + "sample_row_n": {"100": 100, "250": 250, "500": 500,"1000":1000,"1500":1500}} diff --git a/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_25MS_robust_dgp/models.py b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_25MS_robust_dgp/models.py new file mode 100644 index 0000000..7d18df1 --- /dev/null +++ b/feature_importance/fi_config/mdi_plus/robust_sims/enhancer_lss_3m_2r_25MS_robust_dgp/models.py @@ -0,0 +1,22 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +from imodels.importance.ppms import RobustRegressorPPM, LassoRegressorPPM, huber_loss +from sklearn.metrics import mean_absolute_error + + +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+_ridge_loo_r2', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})], + [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/fi_config/test/__init__.py b/feature_importance/fi_config/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/fi_config/test/dgp.py b/feature_importance/fi_config/test/dgp.py new file mode 100644 index 0000000..6f9eb20 --- /dev/null +++ b/feature_importance/fi_config/test/dgp.py @@ -0,0 +1,31 @@ +import sys +sys.path.append("../..") +from feature_importance.scripts.simulations_util import * + + +X_DGP = sample_real_X +X_PARAMS_DICT = { + "fpath": "data/X_splicing_cleaned.csv", + "sample_row_n": None, + "sample_col_n": None +} +Y_DGP = linear_model +Y_PARAMS_DICT = { + "beta": 1, + "sigma": None, + "heritability": 0.4, + "s": 5 +} + +# # vary one parameter +# VARY_PARAM_NAME = "sample_row_n" +# VARY_PARAM_VALS = {"100": 100, "250": 250, "500": 500, 
"1000": 1000} + +# vary two parameters in a grid +VARY_PARAM_NAME = ["heritability", "sample_row_n"] +VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8}, + "sample_row_n": {"100": 100, "250": 250, "500": 500, "1000": 1000}} + +# # vary over n_estimators in RF model in models.py +# VARY_PARAM_NAME = "n_estimators" +# VARY_PARAM_VALS = {"placeholder": 0} diff --git a/feature_importance/fi_config/test/models.py b/feature_importance/fi_config/test/models.py new file mode 100644 index 0000000..6d3537c --- /dev/null +++ b/feature_importance/fi_config/test/models.py @@ -0,0 +1,19 @@ +from sklearn.ensemble import RandomForestRegressor +from feature_importance.util import ModelConfig, FIModelConfig +from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap + +# N_ESTIMATORS=[50, 100, 500, 1000] +ESTIMATORS = [ + [ModelConfig('RF', RandomForestRegressor, model_type='tree', + other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33})], + # [ModelConfig('RF', RandomForestRegressor, model_type='tree', vary_param="n_estimators", vary_param_val=m, + # other_params={'min_samples_leaf': 5, 'max_features': 0.33}) for m in N_ESTIMATORS] +] + +FI_ESTIMATORS = [ + [FIModelConfig('MDI+', tree_mdi_plus, model_type='tree')], + [FIModelConfig('MDI', tree_mdi, model_type='tree')], + [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')], + [FIModelConfig('MDA', tree_mda, model_type='tree')], + [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')] +] diff --git a/feature_importance/notebooks/mdi_plus/01_paper_figures.Rmd b/feature_importance/notebooks/mdi_plus/01_paper_figures.Rmd new file mode 100644 index 0000000..55a46ef --- /dev/null +++ b/feature_importance/notebooks/mdi_plus/01_paper_figures.Rmd @@ -0,0 +1,3205 @@ +--- +title: "MDI+ Simulation Results Summary" +author: "" +date: "`r format(Sys.time(), '%B %d, %Y')`" +output: vthemes::vmodern +params: + results_dir: + label: "Results directory" + value: "results/" + seed: + label: "Seed" + value: 12345 + for_paper: + label: "Export plots for paper" + value: FALSE + use_cached: + label: "Use cached .rds files" + value: TRUE + interactive: + label: "Interactive plots" + value: FALSE +--- + +```{r setup, include=FALSE} +knitr::opts_chunk$set(echo = FALSE, warning = FALSE, message = FALSE) + +library(magrittr) +library(patchwork) +chunk_idx <- 1 + +# set parameters +results_dir <- params$results_dir +seed <- params$seed +tables_dir <- file.path("tables") +figures_dir <- file.path("figures") +figures_subdirs <- c("regression_sims", "classification_sims", "robust_sims", + "misspecified_regression_sims", "varying_sparsity", + "varying_p", "modeling_choices", "glm_metric_choices") +for (figures_subdir in figures_subdirs) { + if (!dir.exists(file.path(figures_dir, figures_subdir))) { + dir.create(file.path(figures_dir, figures_subdir), recursive = TRUE) + } +} + +# miscellaneous helper variables +heritabilities <- c(0.1, 0.2, 0.4, 0.8) +frac_label_corruptions <- c("0.25", "0.15", "0.05", "0") +corrupt_sizes <- c(0.05, 0.025, 0.01, 0) +mean_shifts <- c(10, 25) +metric <- "rocauc" + +# plot options +point_size <- 2 +line_size <- 1 +errbar_width <- 0 +if (params$interactive) { + plot_fun <- plotly::ggplotly +} else { + plot_fun <- function(x) x +} + +manual_color_palette_choices <- c( + "black", "black", "#9B5DFF", "blue", + "orange", "#71beb7", "#218a1e", "#cc3399" +) +show_methods_choices <- c( + "MDI+_ridge_RF", "MDI+_ridge_loo_r2_RF", 
"MDI+_logistic_logloss_RF", "MDI+_Huber_loo_huber_loss_RF", + "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF" +) +method_labels_choices <- c( + "MDI+ (ridge)", "MDI+ (ridge)", "MDI+ (logistic)", "MDI+ (Huber)", + "MDA", "MDI", "MDI-oob", "TreeSHAP" +) +color_df <- tibble::tibble( + color = manual_color_palette_choices, + name = show_methods_choices, + label = method_labels_choices +) + +manual_color_palette_all <- NULL +show_methods_all <- NULL +method_labels_all <- ggplot2::waiver() + +custom_theme <- vthemes::theme_vmodern( + size_preset = "medium", bg_color = "white", grid_color = "white", + axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"), + axis.text = ggplot2::element_text(size = 14), + axis.title = ggplot2::element_text(size = 20, face = "plain"), + legend.text = ggplot2::element_text(size = 14), + plot.title = ggplot2::element_blank() + # plot.title = ggplot2::element_text(size = 12, face = "plain", hjust = 0.5) +) +custom_theme_with_legend <- vthemes::theme_vmodern( + size_preset = "medium", bg_color = "white", grid_color = "white", + axis.title = ggplot2::element_text(size = 12, face = "plain"), + legend.text = ggplot2::element_text(size = 9), + legend.text.align = 0, + plot.title = ggplot2::element_blank() +) +custom_theme_without_legend <- vthemes::theme_vmodern( + size_preset = "medium", bg_color = "white", grid_color = "white", + axis.title = ggplot2::element_text(size = 12, face = "plain"), + legend.title = ggplot2::element_blank(), + legend.text = ggplot2::element_text(size = 9), + legend.text.align = 0, + plot.title = ggplot2::element_blank() +) + +fig_height <- 6 +fig_width <- 10 + +source("../../scripts/viz.R", chdir = TRUE) +``` + + +# Regression Simulations {.tabset .tabset-vmodern} + +```{r eval = TRUE, results = "asis"} +keep_methods <- color_df$name %in% c("MDI+_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF") +manual_color_palette <- color_df$color[keep_methods] +show_methods <- color_df$name[keep_methods] +method_labels <- color_df$label[keep_methods] +alpha_values <- c(1, rep(0.4, length(method_labels) - 1)) +legend_position <- c(0.73, 0.35) + +vary_param_name <- "heritability_sample_row_n" +y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r") +x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing") + +remove_x_axis_models <- c("enhancer", "ccle_rnaseq", "juvenile") +keep_legend_x_models <- c("enhancer") +keep_legend_y_models <- y_models + +for (y_model in y_models) { + cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model)) + for (x_model in x_models) { + cat(sprintf("\n\n### %s \n\n", x_model)) + plt_ls <- list() + sim_name <- sprintf("%s_%s_dgp", x_model, y_model) + sim_title <- dplyr::case_when( + x_model == "ccle_rnaseq" ~ "CCLE", + x_model == "splicing" ~ "Splicing", + x_model == "enhancer" ~ "Enhancer", + x_model == "juvenile" ~ "Juvenile" + ) + metric_name <- dplyr::case_when( + metric == "rocauc" ~ "AUROC", + metric == "prauc" ~ "PRAUC" + ) + fname <- file.path(results_dir, paste0("mdi_plus.regression_sims.", sim_name), + paste0("varying_", vary_param_name), + paste0("seed", seed), "results") + if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) { + results <- readRDS(sprintf("%s.rds", fname)) + if (length(setdiff(show_methods, unique(results$method))) > 0) { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + } else { + results <- data.table::fread(sprintf("%s.csv", 
fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + if (params$for_paper) { + for (h in heritabilities) { + plt <- results %>% + dplyr::filter(heritability == !!h) %>% + plot_metrics( + metric = metric, + x_str = "sample_row_n", + facet_str = NULL, + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + legend_position = legend_position, + custom_theme = custom_theme, + inside_legend = TRUE + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method" + ) + if (h != heritabilities[1]) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if (x_model %in% remove_x_axis_models) { + height <- 2.72 + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + } else { + height <- 3 + } + if (!((h == heritabilities[length(heritabilities)]) & + (x_model %in% keep_legend_x_models) & + (y_model %in% keep_legend_y_models))) { + plt <- plt + ggplot2::guides(color = "none", alpha = "none") + } + plt_ls[[as.character(h)]] <- plt + } + plt <- patchwork::wrap_plots(plt_ls, nrow = 1) + ggplot2::ggsave( + file.path(figures_dir, "regression_sims", + sprintf("regression_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)), + plot = plt, units = "in", width = 14, height = height + ) + } else { + plt <- results %>% + plot_metrics( + metric = metric, + x_str = "sample_row_n", + facet_str = "heritability_name", + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + custom_theme = custom_theme + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method", + title = sprintf("%s", sim_title) + ) + vthemes::subchunkify(plot_fun(plt), i = chunk_idx, + fig_height = fig_height, fig_width = fig_width * 1.3) + chunk_idx <- chunk_idx + 1 + } + } +} + +``` + +```{r eval = TRUE, results = "asis"} +y_models <- c("lss_3m_2r", "hier_poly_3m_2r") +x_models <- c("splicing") + +remove_x_axis_models <- c("lss_3m_2r") +keep_legend_x_models <- x_models +keep_legend_y_models <- c("lss_3m_2r") + +if (params$for_paper) { + for (y_model in y_models) { + for (x_model in x_models) { + plt_ls <- list() + sim_name <- sprintf("%s_%s_dgp", x_model, y_model) + sim_title <- dplyr::case_when( + x_model == "ccle_rnaseq" ~ "CCLE", + x_model == "splicing" ~ "Splicing", + x_model == "enhancer" ~ "Enhancer", + x_model == "juvenile" ~ "Juvenile" + ) + metric_name <- dplyr::case_when( + metric == "rocauc" ~ "AUROC", + metric == "prauc" ~ "PRAUC" + ) + fname <- file.path(results_dir, paste0("mdi_plus.regression_sims.", sim_name), + paste0("varying_", vary_param_name), + paste0("seed", seed), "results") + if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) { + results <- readRDS(sprintf("%s.rds", fname)) + if (length(setdiff(show_methods, unique(results$method))) > 0) { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + } else { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + for (h in heritabilities) { + plt <- results %>% + dplyr::filter(heritability == !!h) %>% + plot_metrics( + metric = metric, + 
x_str = "sample_row_n", + facet_str = NULL, + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + legend_position = legend_position, + custom_theme = custom_theme, + inside_legend = TRUE + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method" + ) + if (h != heritabilities[1]) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if (y_model %in% remove_x_axis_models) { + height <- 2.72 + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + } else { + height <- 3 + } + if (!((h == heritabilities[length(heritabilities)]) & + (x_model %in% keep_legend_x_models) & + (y_model %in% keep_legend_y_models))) { + plt <- plt + ggplot2::guides(color = "none", alpha = "none") + } + plt_ls[[as.character(h)]] <- plt + } + plt <- patchwork::wrap_plots(plt_ls, nrow = 1) + ggplot2::ggsave( + file.path(figures_dir, "regression_sims", + sprintf("main_regression_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)), + plot = plt, units = "in", width = 14, height = height + ) + } + } +} + +``` + + +# Classification Simulations {.tabset .tabset-vmodern} + +```{r eval = TRUE, results = "asis"} +keep_methods <- color_df$name %in% c("MDI+_ridge_RF", "MDI+_logistic_logloss_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF") +manual_color_palette <- color_df$color[keep_methods] +show_methods <- color_df$name[keep_methods] +method_labels <- color_df$label[keep_methods] +alpha_values <- c(1, 1, rep(0.4, length(method_labels) - 2)) +legend_position <- c(0.73, 0.4) + +vary_param_name <- "frac_label_corruption_sample_row_n" +y_models <- c("logistic", "lss_3m_2r_logistic", "hier_poly_3m_2r_logistic", "linear_lss_3m_2r_logistic") +x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing") + +remove_x_axis_models <- c("enhancer", "ccle_rnaseq", "juvenile") +keep_legend_x_models <- c("enhancer") +keep_legend_y_models <- y_models + +for (y_model in y_models) { + cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model)) + for (x_model in x_models) { + cat(sprintf("\n\n### %s \n\n", x_model)) + plt_ls <- list() + sim_name <- sprintf("%s_%s_dgp", x_model, y_model) + sim_title <- dplyr::case_when( + x_model == "ccle_rnaseq" ~ "CCLE", + x_model == "splicing" ~ "Splicing", + x_model == "enhancer" ~ "Enhancer", + x_model == "juvenile" ~ "Juvenile" + ) + metric_name <- dplyr::case_when( + metric == "rocauc" ~ "AUROC", + metric == "prauc" ~ "PRAUC" + ) + fname <- file.path(results_dir, paste0("mdi_plus.classification_sims.", sim_name), + paste0("varying_", vary_param_name), + paste0("seed", seed), "results") + if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) { + results <- readRDS(sprintf("%s.rds", fname)) + if (length(setdiff(show_methods, unique(results$method))) > 0) { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + } else { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + if (params$for_paper) { + for (h in frac_label_corruptions) { + plt <- results %>% + dplyr::filter(frac_label_corruption_name == !!h) %>% + plot_metrics( + metric = metric, + x_str = "sample_row_n", + facet_str = NULL, + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + 
manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + legend_position = legend_position, + custom_theme = custom_theme, + inside_legend = TRUE + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method" + ) + if (h != frac_label_corruptions[1]) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if (x_model %in% remove_x_axis_models) { + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + height <- 2.72 + } else { + height <- 3 + } + if (!((h == frac_label_corruptions[length(frac_label_corruptions)]) & + (x_model %in% keep_legend_x_models) & + (y_model %in% keep_legend_y_models))) { + plt <- plt + ggplot2::guides(color = "none", alpha = "none") + } + plt_ls[[as.character(h)]] <- plt + } + plt <- patchwork::wrap_plots(plt_ls, nrow = 1) + ggplot2::ggsave( + file.path(figures_dir, "classification_sims", + sprintf("classification_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)), + plot = plt, units = "in", width = 14, height = height + ) + } else { + plt <- results %>% + plot_metrics( + metric = metric, + x_str = "sample_row_n", + facet_str = "frac_label_corruption_name", + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + custom_theme = custom_theme + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method", + title = sprintf("%s", sim_title) + ) + vthemes::subchunkify(plot_fun(plt), i = chunk_idx, + fig_height = fig_height, fig_width = fig_width * 1.3) + chunk_idx <- chunk_idx + 1 + } + } +} +``` + +```{r eval = TRUE, results = "asis"} +y_models <- c("logistic", "linear_lss_3m_2r_logistic") +x_models <- c("ccle_rnaseq") + +remove_x_axis_models <- c("logistic") +keep_legend_x_models <- x_models +keep_legend_y_models <- c("logistic") + +if (params$for_paper) { + for (y_model in y_models) { + for (x_model in x_models) { + plt_ls <- list() + sim_name <- sprintf("%s_%s_dgp", x_model, y_model) + sim_title <- dplyr::case_when( + x_model == "ccle_rnaseq" ~ "CCLE", + x_model == "splicing" ~ "Splicing", + x_model == "enhancer" ~ "Enhancer", + x_model == "juvenile" ~ "Juvenile" + ) + metric_name <- dplyr::case_when( + metric == "rocauc" ~ "AUROC", + metric == "prauc" ~ "PRAUC" + ) + fname <- file.path(results_dir, paste0("mdi_plus.classification_sims.", sim_name), + paste0("varying_", vary_param_name), + paste0("seed", seed), "results") + if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) { + results <- readRDS(sprintf("%s.rds", fname)) + if (length(setdiff(show_methods, unique(results$method))) > 0) { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + } else { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + for (h in frac_label_corruptions) { + plt <- results %>% + dplyr::filter(frac_label_corruption_name == !!h) %>% + plot_metrics( + metric = metric, + x_str = "sample_row_n", + facet_str = NULL, + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + legend_position = 
legend_position, + custom_theme = custom_theme, + inside_legend = TRUE + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method" + ) + if (h != frac_label_corruptions[1]) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if (y_model %in% remove_x_axis_models) { + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + height <- 2.72 + } else { + height <- 3 + } + if (!((h == frac_label_corruptions[length(frac_label_corruptions)]) & + (x_model %in% keep_legend_x_models) & + (y_model %in% keep_legend_y_models))) { + plt <- plt + ggplot2::guides(color = "none", alpha = "none") + } + plt_ls[[as.character(h)]] <- plt + } + plt <- patchwork::wrap_plots(plt_ls, nrow = 1) + ggplot2::ggsave( + file.path(figures_dir, "classification_sims", + sprintf("main_classification_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)), + plot = plt, units = "in", width = 14, height = height + ) + } + } +} +``` + + +# Robust Simulations {.tabset .tabset-vmodern} + +```{r eval = TRUE, results = "asis"} +keep_methods <- color_df$name %in% c("MDI+_ridge_loo_r2_RF", "MDI+_Huber_loo_huber_loss_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF") +manual_color_palette <- color_df$color[keep_methods] +show_methods <- color_df$name[keep_methods] +method_labels <- color_df$label[keep_methods] +alpha_values <- c(1, 1, rep(0.4, length(method_labels) - 2)) +legend_position <- c(0.73, 0.4) + +vary_param_name <- "corrupt_size_sample_row_n" +y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r") +x_models <- c("enhancer", "ccle_rnaseq") + +remove_x_axis_models <- c(10) +keep_legend_x_models <- c("enhancer", "ccle_rnaseq") +keep_legend_y_models <- c("linear", "linear_lss_3m_2r") +keep_legend_mean_shifts <- c(10) + +for (y_model in y_models) { + cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model)) + for (x_model in x_models) { + cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-circle} \n\n", x_model)) + sim_title <- dplyr::case_when( + x_model == "ccle_rnaseq" ~ "CCLE", + x_model == "splicing" ~ "Splicing", + x_model == "enhancer" ~ "Enhancer", + x_model == "juvenile" ~ "Juvenile" + ) + for (mean_shift in mean_shifts) { + cat(sprintf("\n\n#### Mean Shift = %s \n\n", mean_shift)) + plt_ls <- list() + sim_name <- sprintf("%s_%s_%sMS_robust_dgp", x_model, y_model, mean_shift) + fname <- file.path(results_dir, paste0("mdi_plus.robust_sims.", sim_name), + paste0("varying_", vary_param_name), + paste0("seed", seed), "results") + if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) { + results <- readRDS(sprintf("%s.rds", fname)) + if (length(setdiff(show_methods, unique(results$method))) > 0) { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + } else { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + metric_name <- dplyr::case_when( + metric == "rocauc" ~ "AUROC", + metric == "prauc" ~ "PRAUC", + TRUE ~ metric + ) + if (params$for_paper) { + tmp <- results %>% + dplyr::filter(corrupt_size_name %in% corrupt_sizes, + method %in% show_methods) %>% + dplyr::group_by(sample_row_n, corrupt_size_name, method) %>% + dplyr::summarise(mean = mean(.data[[metric]])) + min_y <- min(tmp$mean) + max_y <- max(tmp$mean) + for (h in corrupt_sizes) { + plt <- results %>% + dplyr::filter(corrupt_size_name == !!h) %>% + plot_metrics( + metric = 
metric, + x_str = "sample_row_n", + facet_str = NULL, + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + legend_position = legend_position, + custom_theme = custom_theme, + inside_legend = TRUE + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method" + ) + if (h != corrupt_sizes[1]) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if (mean_shift %in% remove_x_axis_models) { + height <- 2.72 + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + } else { + height <- 3 + } + if (!((h == corrupt_sizes[length(corrupt_sizes)]) & + (mean_shift %in% keep_legend_mean_shifts) & + (x_model %in% keep_legend_x_models) & + (y_model %in% keep_legend_y_models))) { + plt <- plt + ggplot2::guides(color = "none", alpha = "none") + } + plt_ls[[as.character(h)]] <- plt + } + plt <- patchwork::wrap_plots(plt_ls, nrow = 1) + ggplot2::ggsave( + file.path(figures_dir, "robust_sims", + sprintf("robust_sims_%s_%s_%sMS_%s_errbars.pdf", y_model, x_model, mean_shift, metric)), + plot = plt, units = "in", width = 14, height = height + ) + } else { + plt <- results %>% + plot_metrics( + metric = metric, + x_str = "sample_row_n", + facet_str = "corrupt_size_name", + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + custom_theme = custom_theme + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method", + title = sprintf("%s", sim_title) + ) + vthemes::subchunkify(plot_fun(plt), i = chunk_idx, + fig_height = fig_height, fig_width = fig_width * 1.3) + chunk_idx <- chunk_idx + 1 + } + } + } +} + +``` + +```{r eval = TRUE, results = "asis"} +y_models <- c("lss_3m_2r") +x_models <- c("enhancer") + +remove_x_axis_models <- c(10) +keep_legend_x_models <- x_models +keep_legend_y_models <- y_models +keep_legend_mean_shifts <- c(10) + +if (params$for_paper) { + for (y_model in y_models) { + for (x_model in x_models) { + sim_title <- dplyr::case_when( + x_model == "ccle_rnaseq" ~ "CCLE", + x_model == "splicing" ~ "Splicing", + x_model == "enhancer" ~ "Enhancer", + x_model == "juvenile" ~ "Juvenile" + ) + for (mean_shift in mean_shifts) { + plt_ls <- list() + sim_name <- sprintf("%s_%s_%sMS_robust_dgp", x_model, y_model, mean_shift) + fname <- file.path(results_dir, paste0("mdi_plus.robust_sims.", sim_name), + paste0("varying_", vary_param_name), + paste0("seed", seed), "results") + if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) { + results <- readRDS(sprintf("%s.rds", fname)) + if (length(setdiff(show_methods, unique(results$method))) > 0) { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + } else { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + metric_name <- dplyr::case_when( + metric == "rocauc" ~ "AUROC", + metric == "prauc" ~ "PRAUC", + TRUE ~ metric + ) + tmp <- results %>% + dplyr::filter(corrupt_size_name %in% corrupt_sizes, + method %in% show_methods) %>% + dplyr::group_by(sample_row_n, corrupt_size_name, method) %>% + dplyr::summarise(mean = mean(.data[[metric]])) + min_y 
<- min(tmp$mean) + max_y <- max(tmp$mean) + for (h in corrupt_sizes) { + plt <- results %>% + dplyr::filter(corrupt_size_name == !!h) %>% + plot_metrics( + metric = metric, + x_str = "sample_row_n", + facet_str = NULL, + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + legend_position = legend_position, + custom_theme = custom_theme, + inside_legend = TRUE + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method" + ) + if (h != corrupt_sizes[1]) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if (mean_shift %in% remove_x_axis_models) { + height <- 2.72 + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + } else { + height <- 3 + } + if (!((h == corrupt_sizes[length(corrupt_sizes)]) & + (mean_shift %in% keep_legend_mean_shifts) & + (x_model %in% keep_legend_x_models) & + (y_model %in% keep_legend_y_models))) { + plt <- plt + ggplot2::guides(color = "none", alpha = "none") + } + plt_ls[[as.character(h)]] <- plt + } + plt <- patchwork::wrap_plots(plt_ls, nrow = 1) + ggplot2::ggsave( + file.path(figures_dir, "robust_sims", + sprintf("main_robust_sims_%s_%s_%sMS_%s_errbars.pdf", y_model, x_model, mean_shift, metric)), + plot = plt, units = "in", width = 14, height = height + ) + } + } + } +} + +``` + + +# Correlation Bias Simulations {.tabset .tabset-vmodern} + +## Main Figures {.tabset .tabset-pills .tabset-square} + +```{r eval = TRUE, results = "asis"} +manual_color_palette <- c("black", "#218a1e", "orange", "#71beb7", "#cc3399") +show_methods <- c("MDI+", "MDI-oob", "MDA", "MDI", "TreeSHAP") +method_labels <- c("MDI+ (ridge)", "MDI-oob", "MDA", "MDI", "TreeSHAP") + +custom_color_palette <- c("#38761d", "#9dc89b", "#991500") +keep_heritabilities <- c(0.1, 0.4) +sim_name <- "mdi_plus.mdi_bias_sims.correlation_sims.normal_block_cor_partial_linear_lss_dgp" +fname <- file.path( + results_dir, + sim_name, + "varying_heritability_rho", + sprintf("seed%s", seed), + "results.csv" +) +results <- data.table::fread(fname) %>% + dplyr::filter(heritability_name %in% keep_heritabilities) + +n <- 250 +p <- 100 +sig_ids <- 0:5 +cnsig_ids <- 6:49 +nreps <- length(unique(results$rep)) + +#### Examine average rank across varying correlation levels #### +plt_base <- plot_perturbation_stability( + results, + facet_rows = "heritability_name", + facet_cols = "rho_name", + param_name = NULL, + sig_ids = sig_ids, + cnsig_ids = cnsig_ids, + plot_types = "errbar", + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels +) + +for (idx in 1:length(keep_heritabilities)) { + h <- keep_heritabilities[idx] + cat(sprintf("\n\n### PVE = %s\n\n", h)) + plt_df <- plt_base$agg[[idx]]$data + plt_ls <- list() + for (method in method_labels) { + plt_ls[[method]] <- plt_df %>% + dplyr::filter(fi == method) %>% + ggplot2::ggplot() + + ggplot2::aes(x = rho_name, y = .mean, color = group) + + ggplot2::geom_line() + + ggplot2::geom_ribbon( + ggplot2::aes(x = rho_name, + ymin = .mean - (.sd / sqrt(nreps)), + ymax = .mean + (.sd / sqrt(nreps)), + fill = group), + inherit.aes = FALSE, alpha = 0.2 + ) + + ggplot2::labs( + x = expression("Correlation ("*rho*")"), + y = "Average Rank", + color = "Feature\nGroup", + title = method + ) + + ggplot2::coord_cartesian( + ylim = c(min(plt_df$.mean) - 3, max(plt_df$.mean) + 3) + ) + + 
ggplot2::scale_color_manual( + values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE) + ) + + ggplot2::scale_fill_manual( + values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE) + ) + + ggplot2::guides(fill = "none", linetype = "none") + + vthemes::theme_vmodern( + size_preset = "medium", bg_color = "white", grid_color = "white", + axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"), + axis.text = ggplot2::element_text(size = 14), + axis.title = ggplot2::element_text(size = 18, face = "plain"), + legend.text = ggplot2::element_text(size = 14), + legend.title = ggplot2::element_text(size = 18), + plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5) + ) + if (method != method_labels[1]) { + plt_ls[[method]] <- plt_ls[[method]] + + ggplot2::theme(axis.title.y = ggplot2::element_blank(), + axis.ticks.y = ggplot2::element_blank(), + axis.text.y = ggplot2::element_blank()) + } + } + + plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1) + if (params$for_paper) { + ggplot2::ggsave(plt_wide, filename = file.path(figures_dir, sprintf("correlation_sim_pve%s_wide.pdf", h)), + # height = fig_height * .55, width = fig_width * .25 * length(show_methods)) + height = fig_height * .55, width = fig_width * .3 * length(show_methods)) + } else { + vthemes::subchunkify( + plt_wide, i = chunk_idx, fig_height = fig_height * .8, fig_width = fig_width * 1.3 + ) + chunk_idx <- chunk_idx + 1 + } +} + + +#### Examine number of splits across varying correlation levels #### +cat("\n\n### Number of RF Splits \n\n") +results <- data.table::fread(fname) %>% + dplyr::filter(heritability_name %in% keep_heritabilities, + fi == "MDI_with_splits") + +total_splits <- results %>% + dplyr::group_by(rho_name, heritability, rep) %>% + dplyr::summarise(n_splits = sum(av_splits)) + +plt_df <- results %>% + dplyr::left_join( + total_splits, by = c("rho_name", "heritability", "rep") + ) %>% + dplyr::mutate( + av_splits = av_splits / n_splits * 100, + group = dplyr::case_when( + var %in% sig_ids ~ "Sig", + var %in% cnsig_ids ~ "C-NSig", + TRUE ~ "NSig" + ) + ) %>% + dplyr::mutate( + group = factor(group, levels = c("Sig", "C-NSig", "NSig")) + ) %>% + dplyr::group_by(rho_name, heritability, group) %>% + dplyr::summarise( + .mean = mean(av_splits), + .sd = sd(av_splits) + ) + +min_y <- min(plt_df$.mean) +max_y <- max(plt_df$.mean) + +plt_ls <- list() +for (idx in 1:length(keep_heritabilities)) { + h <- keep_heritabilities[idx] + plt <- plt_df %>% + dplyr::filter(heritability == !!h) %>% + ggplot2::ggplot() + + ggplot2::aes(x = rho_name, y = .mean, color = group) + + ggplot2::geom_line(size = 1) + + # ggplot2::geom_ribbon( + # ggplot2::aes(x = rho_name, + # ymin = .mean - (.sd / sqrt(nreps)), + # ymax = .mean + (.sd / sqrt(nreps)), + # fill = group), + # inherit.aes = FALSE, alpha = 0.2 + # ) + + ggplot2::labs( + x = expression("Correlation ("*rho*")"), + y = "Percentage of Splits in RF (%)", + color = "Feature\nGroup", + title = sprintf("PVE = %s", h) + ) + + ggplot2::coord_cartesian(ylim = c(min_y, max_y)) + + ggplot2::scale_color_manual( + values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE) + ) + + ggplot2::scale_fill_manual( + values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE) + ) + + ggplot2::guides(fill = "none", linetype = "none") + + vthemes::theme_vmodern( + size_preset = "medium", bg_color = "white", grid_color = "white", + axis.ticks = ggplot2::element_line(size = 
ggplot2::rel(2), colour = "black"), + axis.text = ggplot2::element_text(size = 14), + axis.title = ggplot2::element_text(size = 18, face = "plain"), + legend.text = ggplot2::element_text(size = 14), + legend.title = ggplot2::element_text(size = 18), + plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5) + ) + if (idx != 1) { + plt <- plt + + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + plt_ls[[idx]] <- plt +} + +plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1) +if (params$for_paper) { + ggplot2::ggsave(plt_wide & ggplot2::theme(plot.title = ggplot2::element_blank()), + filename = file.path(figures_dir, "correlation_sim_num_splits.pdf"), + width = fig_width * 1, height = fig_height * .65) +} else { + vthemes::subchunkify( + plt_wide, i = chunk_idx, fig_height = fig_height, fig_width = fig_width * 1.3 + ) + chunk_idx <- chunk_idx + 1 +} + +``` + +## Appendix Figures {.tabset .tabset-pills .tabset-square} + +```{r eval = TRUE, results = "asis"} +manual_color_palette <- c("black", "gray") +show_methods <- c("MDI+", "MDI+_inbag") +method_labels <- c("MDI+ (LOO)", "MDI+ (in-bag)") + +custom_color_palette <- c("#38761d", "#9dc89b", "#991500") +keep_heritabilities <- c(0.1, 0.4) +sim_name <- "mdi_plus.mdi_bias_sims.correlation_sims.normal_block_cor_partial_linear_lss_dgp" +fname <- file.path( + results_dir, + sim_name, + "varying_heritability_rho", + sprintf("seed%s", seed), + "results.csv" +) +results <- data.table::fread(fname) %>% + dplyr::filter(heritability_name %in% keep_heritabilities) + +n <- 250 +p <- 100 +sig_ids <- 0:5 +cnsig_ids <- 6:49 +nreps <- length(unique(results$rep)) + +#### Examine average rank across varying correlation levels #### +plt_base <- plot_perturbation_stability( + results, + facet_rows = "heritability_name", + facet_cols = "rho_name", + param_name = NULL, + sig_ids = sig_ids, + cnsig_ids = cnsig_ids, + plot_types = "errbar", + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels +) + +for (idx in 1:length(keep_heritabilities)) { + h <- keep_heritabilities[idx] + cat(sprintf("\n\n### PVE = %s\n\n", h)) + plt_df <- plt_base$agg[[idx]]$data + plt_ls <- list() + for (method in method_labels) { + plt_ls[[method]] <- plt_df %>% + dplyr::filter(fi == method) %>% + ggplot2::ggplot() + + ggplot2::aes(x = rho_name, y = .mean, color = group) + + ggplot2::geom_line() + + ggplot2::geom_ribbon( + ggplot2::aes(x = rho_name, + ymin = .mean - (.sd / sqrt(nreps)), + ymax = .mean + (.sd / sqrt(nreps)), + fill = group), + inherit.aes = FALSE, alpha = 0.2 + ) + + ggplot2::labs( + x = expression("Correlation ("*rho*")"), + y = "Average Rank", + color = "Feature\nGroup", + title = method + ) + + ggplot2::coord_cartesian( + ylim = c(min(plt_df$.mean) - 3, max(plt_df$.mean) + 3) + ) + + ggplot2::scale_color_manual( + values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE) + ) + + ggplot2::scale_fill_manual( + values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE) + ) + + ggplot2::guides(fill = "none", linetype = "none") + + vthemes::theme_vmodern( + size_preset = "medium", bg_color = "white", grid_color = "white", + axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"), + axis.text = ggplot2::element_text(size = 14), + axis.title = ggplot2::element_text(size = 18, face = "plain"), + legend.text = ggplot2::element_text(size = 14), + legend.title = ggplot2::element_text(size = 18), + plot.title = 
ggplot2::element_text(size = 20, face = "bold", hjust = 0.5) + ) + if (method != method_labels[1]) { + plt_ls[[method]] <- plt_ls[[method]] + + ggplot2::theme(axis.title.y = ggplot2::element_blank(), + axis.ticks.y = ggplot2::element_blank(), + axis.text.y = ggplot2::element_blank()) + } + } + + plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1) + if (params$for_paper) { + ggplot2::ggsave(plt_wide, + filename = file.path(figures_dir, sprintf("correlation_sim_pve%s_wide_appendix.pdf", h)), + # height = fig_height * .55, width = fig_width * .25 * length(show_methods)) + height = fig_height * .55, width = fig_width * .37 * length(show_methods)) + } else { + vthemes::subchunkify( + plt_wide, i = chunk_idx, fig_height = fig_height, fig_width = fig_width * 1.3 + ) + chunk_idx <- chunk_idx + 1 + } +} +``` + + +# Entropy Bias Simulations {.tabset .tabset-vmodern} + +## Main Figures {.tabset .tabset-pills .tabset-square} + +```{r eval = TRUE, results = "asis"} +#### Entropy Regression Results #### +manual_color_palette <- rev(c("black", "orange", "#71beb7", "#218a1e", "#cc3399")) +show_methods <- rev(c("MDI+", "MDA", "MDI", "MDI-oob", "TreeSHAP")) +method_labels <- rev(c("MDI+ (ridge)", "MDA", "MDI", "MDI-oob", "TreeSHAP")) +alpha_values <- rev(c(1, rep(0.7, length(method_labels) - 1))) +y_limits <- c(1, 4.5) + +metric_name <- "AUROC" +sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.linear_dgp" +fname <- file.path( + results_dir, + sim_name, + "varying_heritability_n", + sprintf("seed%s", seed), + "results.csv" +) + +#### Entropy Regression Avg Rank #### +results <- data.table::fread(fname) %>% + dplyr::filter(heritability_name == 0.1) + +custom_group_fun <- function(var) var + +plt_base <- plot_perturbation_stability( + results, + facet_rows = "heritability_name", + facet_cols = "n_name", + group_fun = custom_group_fun, + param_name = NULL, + plot_types = "errbar", + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels +) + +plt_df <- plt_base$agg[[1]]$data +nreps <- length(unique(results$rep)) + +plt_regression <- plt_df %>% + dplyr::filter(group == 0) %>% + ggplot2::ggplot() + + ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi) + + ggplot2::geom_line(size = line_size) + + ggplot2::labs( + x = "Sample Size", + y = expression("Average Rank of "*X[1]), + color = "Method", + title = "Regression" + ) + + ggplot2::guides(fill = "none", alpha = "none") + + ggplot2::scale_color_manual(values = manual_color_palette, + labels = method_labels, + guide = ggplot2::guide_legend(reverse = TRUE)) + + ggplot2::scale_alpha_manual(values = alpha_values, + labels = method_labels, + guide = ggplot2::guide_legend(reverse = TRUE)) + + ggplot2::coord_cartesian(ylim = y_limits) + + vthemes::theme_vmodern( + size_preset = "medium", bg_color = "white", grid_color = "white", + axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"), + axis.text = ggplot2::element_text(size = 14), + axis.title = ggplot2::element_text(size = 18, face = "plain"), + legend.text = ggplot2::element_text(size = 14), + legend.title = ggplot2::element_text(size = 18), + plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5) + ) + +#### Entropy Regression # Splits #### +results <- data.table::fread(fname) %>% + dplyr::filter(heritability_name == 0.1, + fi == "MDI_with_splits") + +nreps <- length(unique(results$rep)) + +total_splits <- results %>% + dplyr::group_by(n_name, heritability, rep) %>% + dplyr::summarise(n_splits = 
sum(av_splits)) + +regression_num_splits_df <- results %>% + dplyr::left_join( + total_splits, by = c("n_name", "heritability", "rep") + ) %>% + dplyr::mutate( + av_splits = av_splits / n_splits * 100 + ) %>% + dplyr::group_by(n_name, heritability, var) %>% + dplyr::summarise( + .mean = mean(av_splits), + .sd = sd(av_splits) + ) + +#### Entropy Classification Results #### +manual_color_palette <- rev(c("black", "#9B5DFF", "orange", "#71beb7", "#218a1e", "#cc3399")) +show_methods <- rev(c("MDI+_ridge", "MDI+_logistic_logloss", "MDA", "MDI", "MDI-oob", "TreeSHAP")) +method_labels <- rev(c("MDI+ (ridge)", "MDI+ (logistic)", "MDA", "MDI", "MDI-oob", "TreeSHAP")) +alpha_values <- rev(c(1, 1, rep(0.7, length(method_labels) - 2))) + +metric_name <- "AUROC" +sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.logistic_dgp" +fname <- file.path( + results_dir, + sim_name, + "varying_c_n", + sprintf("seed%s", seed), + "results.csv" +) + +#### Entropy Classification Avg Rank #### +results <- data.table::fread(fname) + +custom_group_fun <- function(var) var + +plt_base <- plot_perturbation_stability( + results, + facet_rows = "c_name", + facet_cols = "n_name", + group_fun = custom_group_fun, + param_name = NULL, + plot_types = "errbar", + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels +) + +plt_df <- plt_base$agg[[1]]$data +nreps <- length(unique(results$rep)) + +plt_classification <- plt_df %>% + dplyr::filter(group == 0) %>% + ggplot2::ggplot() + + ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi) + + ggplot2::geom_line(size = line_size) + + ggplot2::labs( + x = "Sample Size", + y = expression("Average Rank of "*X[1]), + color = "Method", + title = "Classification" + ) + + ggplot2::guides(fill = "none", alpha = "none") + + ggplot2::scale_color_manual(values = manual_color_palette, + labels = method_labels, + guide = ggplot2::guide_legend(reverse = TRUE)) + + ggplot2::scale_alpha_manual(values = alpha_values, + labels = method_labels, + guide = ggplot2::guide_legend(reverse = TRUE)) + + ggplot2::coord_cartesian( + ylim = y_limits + ) + + vthemes::theme_vmodern( + size_preset = "medium", bg_color = "white", grid_color = "white", + axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"), + axis.text = ggplot2::element_text(size = 14), + axis.title = ggplot2::element_text(size = 18, face = "plain"), + legend.text = ggplot2::element_text(size = 14), + legend.title = ggplot2::element_text(size = 18), + plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5) + ) + +#### Entropy Classification # Splits #### +results <- data.table::fread(fname) %>% + dplyr::filter(fi == "MDI_with_splits") + +total_splits <- results %>% + dplyr::group_by(n_name, c, rep) %>% + dplyr::summarise(n_splits = sum(av_splits)) + +classification_num_splits_df <- results %>% + dplyr::left_join( + total_splits, by = c("n_name", "c", "rep") + ) %>% + dplyr::mutate( + av_splits = av_splits / n_splits * 100 + ) %>% + dplyr::group_by(n_name, c, var) %>% + dplyr::summarise( + .mean = mean(av_splits), + .sd = sd(av_splits) + ) + +#### Show Avg Rank Plot #### +cat("\n\n### Average Rank \n\n") +plt <- plt_regression + plt_classification +if (params$for_paper) { + ggplot2::ggsave(plt, filename = file.path(figures_dir, "entropy_sims.pdf"), + width = fig_width * 1.2, height = fig_height * .65) +} else { + vthemes::subchunkify( + plt, i = chunk_idx, fig_height = fig_height * .8, fig_width = fig_width * 1.3 + ) + chunk_idx <- chunk_idx + 1 +} + 
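# Editor's illustrative sketch (toy, hypothetical values; mirrors the normalization used
# for the `*_num_splits_df` objects above): per-feature split counts (`av_splits`) are
# converted to a percentage of all RF splits within each (sample size, rep) cell before
# averaging, which is what the "Number of RF Splits" panels below display.
toy <- data.frame(n_name = 100, rep = c(1, 1, 2, 2), var = c(0, 1, 0, 1),
                  av_splits = c(30, 10, 24, 16))
toy_totals <- toy %>%
  dplyr::group_by(n_name, rep) %>%
  dplyr::summarise(n_splits = sum(av_splits), .groups = "drop")
toy %>%
  dplyr::left_join(toy_totals, by = c("n_name", "rep")) %>%
  dplyr::mutate(pct_splits = av_splits / n_splits * 100)
# In this toy example, feature 0 accounts for 75% and 60% of splits in reps 1 and 2.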
+#### Show # Splits Plot #### +cat("\n\n### Number of RF Splits \n\n") +plt_df <- dplyr::bind_rows( + Regression = regression_num_splits_df, + Classification = classification_num_splits_df, + .id = "type" +) %>% + dplyr::mutate( + var = dplyr::case_when( + var == 0 ~ "Bernoulli (Signal)", + var == 1 ~ "Normal (Non-Signal)", + var == 2 ~ "4-Categories (Non-Signal)", + var == 3 ~ "10-Categories (Non-Signal)", + var == 4 ~ "20-Categories (Non-Signal)" + ) %>% + factor(levels = c("Normal (Non-Signal)", + "20-Categories (Non-Signal)", + "10-Categories (Non-Signal)", + "4-Categories (Non-Signal)", + "Bernoulli (Signal)")) + ) + +sim_types <- unique(plt_df$type) +plt_ls <- list() +for (idx in 1:length(sim_types)) { + sim_type <- sim_types[idx] + plt <- plt_df %>% + dplyr::filter(type == !!sim_type) %>% + ggplot2::ggplot() + + ggplot2::aes(x = n_name, y = .mean, color = as.factor(var)) + + ggplot2::geom_line(size = 1) + + ggplot2::labs( + x = "Sample Size", + y = "Percentage of Splits in RF (%)", + color = "Feature", + title = sim_type + ) + + ggplot2::guides(fill = "none", linetype = "none") + + vthemes::theme_vmodern( + size_preset = "medium", bg_color = "white", grid_color = "white", + axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"), + axis.text = ggplot2::element_text(size = 14), + axis.title = ggplot2::element_text(size = 18, face = "plain"), + legend.text = ggplot2::element_text(size = 14), + legend.title = ggplot2::element_text(size = 18), + plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5) + ) + if (idx != 1) { + plt <- plt + + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + plt_ls[[idx]] <- plt +} + +plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1) +if (params$for_paper) { + ggplot2::ggsave(plt_wide, filename = file.path(figures_dir, "entropy_sims_num_splits.pdf"), + # height = fig_height * .55, width = fig_width * .25 * length(show_methods)) + width = fig_width * 1.1, height = fig_height * .65) +} else { + vthemes::subchunkify( + plt_wide, i = chunk_idx, fig_height = fig_height, fig_width = fig_width * 1.3 + ) + chunk_idx <- chunk_idx + 1 +} + +``` + +## Appendix Figures {.tabset .tabset-pills .tabset-square} + +```{r eval = TRUE, results = "asis"} +#### Entropy Regression Results #### +manual_color_palette <- rev(c("black", "black", "#71beb7")) +show_methods <- rev(c("MDI+", "MDI+_inbag", "MDI")) +method_labels <- rev(c("MDI+ (ridge, LOO)", "MDI+ (ridge, in-bag)", "MDI")) +alpha_values <- rev(c(1, 1, rep(1, length(method_labels) - 2))) +linetype_values <- c("solid", "dashed", "solid") +y_limits <- c(1, 5) + +metric_name <- "AUROC" +sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.linear_dgp" +fname <- file.path( + results_dir, + sim_name, + "varying_heritability_n", + sprintf("seed%s", seed), + "results.csv" +) + +#### Entropy Regression Avg Rank #### +results <- data.table::fread(fname) %>% + dplyr::filter(heritability_name == 0.1) + +custom_group_fun <- function(var) var + +plt_base <- plot_perturbation_stability( + results, + facet_rows = "heritability_name", + facet_cols = "n_name", + group_fun = custom_group_fun, + param_name = NULL, + plot_types = "errbar", + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels +) + +plt_df <- plt_base$agg[[1]]$data +nreps <- length(unique(results$rep)) + +plt_regression <- plt_df %>% + dplyr::filter(group == 0) %>% + ggplot2::ggplot() + + ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi, 
linetype = fi) + + ggplot2::geom_line(size = line_size) + + ggplot2::labs( + x = "Sample Size", + y = expression("Average Rank of "*X[1]), + color = "Method", + linetype = "Method", + title = "Regression" + ) + + ggplot2::guides(fill = "none", alpha = "none") + + ggplot2::scale_color_manual(values = manual_color_palette, + labels = method_labels, + guide = ggplot2::guide_legend(reverse = TRUE)) + + ggplot2::scale_alpha_manual(values = alpha_values, + labels = method_labels, + guide = ggplot2::guide_legend(reverse = TRUE)) + + ggplot2::scale_linetype_manual(values = linetype_values, + labels = method_labels, + guide = ggplot2::guide_legend(reverse = TRUE)) + + ggplot2::coord_cartesian(ylim = y_limits) + + vthemes::theme_vmodern( + size_preset = "medium", bg_color = "white", grid_color = "white", + axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"), + axis.text = ggplot2::element_text(size = 14), + axis.title = ggplot2::element_text(size = 18, face = "plain"), + legend.text = ggplot2::element_text(size = 14), + legend.title = ggplot2::element_text(size = 18), + plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5), + legend.key.width = ggplot2::unit(1, "cm") + ) + +#### Entropy Classification Results #### +manual_color_palette <- rev(c("black", "black", "#9B5DFF", "#9B5DFF", "#71beb7")) +show_methods <- rev(c("MDI+_ridge", "MDI+_ridge_inbag", "MDI+_logistic_logloss", "MDI+_logistic_logloss_inbag", "MDI")) +method_labels <- rev(c("MDI+ (ridge, LOO)", "MDI+ (ridge, in-bag)", "MDI+ (logistic, LOO)", "MDI+ (logistic, in-bag)", "MDI")) +alpha_values <- rev(c(1, 1, 1, 1, rep(1, length(method_labels) - 4))) +linetype_values <- c("solid", "dashed", "solid", "dashed", "solid") + +metric_name <- "AUROC" +sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.logistic_dgp" +fname <- file.path( + results_dir, + sim_name, + "varying_c_n", + sprintf("seed%s", seed), + "results.csv" +) + +#### Entropy Classification Avg Rank #### +results <- data.table::fread(fname) + +custom_group_fun <- function(var) var + +plt_base <- plot_perturbation_stability( + results, + facet_rows = "c_name", + facet_cols = "n_name", + group_fun = custom_group_fun, + param_name = NULL, + plot_types = "errbar", + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels +) + +plt_df <- plt_base$agg[[1]]$data +nreps <- length(unique(results$rep)) + +plt_classification <- plt_df %>% + dplyr::filter(group == 0) %>% + ggplot2::ggplot() + + ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi, linetype = fi) + + ggplot2::geom_line(size = line_size) + + ggplot2::labs( + x = "Sample Size", + y = expression("Average Rank of "*X[1]), + color = "Method", + linetype = "Method", + title = "Classification" + ) + + ggplot2::guides(fill = "none", alpha = "none") + + ggplot2::scale_color_manual(values = manual_color_palette, + labels = method_labels, + guide = ggplot2::guide_legend(reverse = TRUE)) + + ggplot2::scale_alpha_manual(values = alpha_values, + labels = method_labels, + guide = ggplot2::guide_legend(reverse = TRUE)) + + ggplot2::scale_linetype_manual(values = linetype_values, + labels = method_labels, + guide = ggplot2::guide_legend(reverse = TRUE)) + + ggplot2::coord_cartesian( + ylim = y_limits + ) + + vthemes::theme_vmodern( + size_preset = "medium", bg_color = "white", grid_color = "white", + axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"), + axis.text = ggplot2::element_text(size = 14), + axis.title = 
ggplot2::element_text(size = 18, face = "plain"), + legend.text = ggplot2::element_text(size = 14), + legend.title = ggplot2::element_text(size = 18), + plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5), + legend.key.width = ggplot2::unit(1, "cm") + ) + +#### Show Avg Rank Plot #### +cat("\n\n### Average Rank \n\n") +plt <- plt_regression + plt_classification +if (params$for_paper) { + ggplot2::ggsave(plt, filename = file.path(figures_dir, "entropy_sims_appendix.pdf"), + width = fig_width * 1.35, height = fig_height * .65) +} else { + vthemes::subchunkify( + plt, i = chunk_idx, fig_height = fig_height * .7, fig_width = fig_width * 1.3 + ) + chunk_idx <- chunk_idx + 1 +} + +``` + + +# CCLE Case Study {.tabset .tabset-pills .tabset-square} + +```{r eval = TRUE, results = "asis"} +manual_color_palette <- c("black", "orange", "#71beb7", "#218a1e", "#cc3399") +show_methods <- c("MDI+", "MDA", "MDI", "MDI-oob", "TreeSHAP") +method_labels <- c("MDI+ (ridge)", "MDA", "MDI", "MDI-oob", "TreeSHAP") +alpha_values <- c(1, rep(0.4, length(method_labels) - 1)) + +fpath <- "mdi_plus.real_data_case_study.ccle_rnaseq_regression-" +fname <- file.path( + results_dir, + fpath, + sprintf("seed%s", seed), + "results.csv" +) +pred_fname <- file.path( + results_dir, + fpath, + sprintf("seed%s", seed), + "pred_results.csv" +) +X <- data.table::fread("../../data/X_ccle_rnaseq_cleaned_filtered5000.csv") +X_columns <- colnames(X) +results <- data.table::fread(fname) +pred_results <- data.table::fread(pred_fname) %>% + tibble::as_tibble() %>% + dplyr::mutate(method = fi) %>% + # tidyr::unite(col = "method", fi, model, na.rm = TRUE, remove = FALSE) %>% + tidyr::pivot_wider(names_from = metric, values_from = metric_value) + +out <- plot_top_stability( + results, + group_id = "y_task", + top_r = 10, + show_max_features = 5, + varnames = colnames(X), + base_method = "MDI+ (ridge)", + return_df = TRUE, + manual_color_palette = rev(manual_color_palette), + show_methods = rev(show_methods), + method_labels = rev(method_labels) +) + +ranking_df <- out$rankings +stability_df <- out$stability +plt_ls <- out$plot_ls + +# get gene names instead of ENSG ids +library(EnsDb.Hsapiens.v79) +varnames <- ensembldb::select( + EnsDb.Hsapiens.v79, + keys = stringr::str_remove(colnames(X), "\\.[^\\.]+$"), + keytype = "GENEID", columns = c("GENEID","SYMBOL") +) + +# get top 5 genes +top_genes <- ranking_df %>% + dplyr::group_by(y_task, fi, var) %>% + dplyr::summarise(mean_rank = mean(rank)) %>% + dplyr::ungroup() %>% + dplyr::arrange(y_task, mean_rank) %>% + dplyr::group_by(y_task, fi) %>% + dplyr::mutate(rank = 1:dplyr::n()) %>% + dplyr::ungroup() + +top_genes_table <- top_genes %>% + dplyr::filter(rank <= 50) %>% + dplyr::left_join( + y = data.frame(var = 0:(ncol(X) - 1), + GENEID = stringr::str_remove(colnames(X), "\\.[^\\.]+$")), + by = "var" + ) %>% + dplyr::left_join( + y = varnames, by = "GENEID" + ) %>% + dplyr::mutate( + value = sprintf("%s (%s)", SYMBOL, round(mean_rank, 2)) + ) %>% + tidyr::pivot_wider( + id_cols = c(y_task, rank), + names_from = fi, + values_from = value + ) %>% + dplyr::rename( + "Drug" = "y_task", + "Rank" = "rank" + ) + +# write.csv( +# top_genes_table, file.path(tables_dir, "ccle_rnaseq_top_genes_table.csv"), +# row.names = FALSE, quote = FALSE +# ) + +# get # genes with perfect stability +stable_genes_kable <- stability_df %>% + dplyr::group_by(fi, y_task) %>% + dplyr::summarise(stability_1 = sum(stability_score == 1)) %>% + tidyr::pivot_wider(id_cols = fi, names_from = y_task, 
values_from = stability_1) %>% + dplyr::ungroup() %>% + as.data.frame() + +pred_plt_ls <- list() +for (plt_id in names(plt_ls)) { + cat(sprintf("\n\n## %s \n\n", plt_id)) + plt <- plt_ls[[plt_id]] + if (plt_id == "Summary") { + plt <- plt + ggplot2::labs(y = "Drug") + if (params$for_paper) { + ggplot2::ggsave(plt, filename = file.path(figures_dir, "ccle_rnaseq_stability_top10.pdf"), + width = fig_width * 1.5, height = fig_height * 1.1) + } else { + vthemes::subchunkify( + plt, i = sprintf("%s-%s", fpath, plt_id), + fig_height = fig_height, fig_width = fig_width * 1.3 + ) + } + + cat("\n\n## Table of Top Genes \n\n") + if (params$for_paper) { + top_genes_kable <- top_genes_table %>% + dplyr::filter(Rank <= 5) %>% + dplyr::select(-Drug) %>% + dplyr::rename(" " = "Rank") %>% + vthemes::pretty_kable(format = "latex", longtable = TRUE) %>% + kableExtra::kable_styling(latex_options = "repeat_header") %>% + # kableExtra::collapse_rows(columns = 1, valign = "middle") + kableExtra::pack_rows( + index = rep(5, 24) %>% setNames(unique(top_genes_table$Drug)) + ) + } else { + top_genes_kable <- top_genes_table %>% + dplyr::filter(Rank <= 5) %>% + dplyr::select(-Drug) %>% + dplyr::rename(" " = "Rank") %>% + vthemes::pretty_kable() %>% + kableExtra::kable_styling(latex_options = "repeat_header") %>% + # kableExtra::collapse_rows(columns = 1, valign = "middle") + kableExtra::pack_rows( + index = rep(5, 24) %>% setNames(unique(top_genes_table$Drug)) + ) + vthemes::subchunkify( + top_genes_kable, i = sprintf("%s-table", fpath) + ) + } + } else { + pred_plt_ls[[plt_id]] <- pred_results %>% + dplyr::filter(k <= 10, y_task == !!plt_id) %>% + plot_metrics( + metric = "r2", + x_str = "k", + facet_str = NULL, + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + inside_legend = FALSE + ) + + ggplot2::labs(x = "Number of Top Features (k)", y = "Test R-squared", + color = "Method", alpha = "Method", title = plt_id) + + vthemes::theme_vmodern() + + if (!params$for_paper) { + vthemes::subchunkify( + plot_fun(plt), i = sprintf("%s-%s", fpath, plt_id), + fig_height = fig_height, fig_width = fig_width * 1.3 + ) + vthemes::subchunkify( + plot_fun(pred_plt_ls[[plt_id]]), i = sprintf("%s-%s-pred", fpath, plt_id), + fig_height = fig_height, fig_width = fig_width * 1.3 + ) + } + } +} + +if (params$for_paper) { + all_plt_ls <- list() + for (idx in 1:length(pred_plt_ls)) { + drug <- names(pred_plt_ls)[idx] + plt <- pred_plt_ls[[drug]] + if ((idx - 1) %% 4 != 0) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if ((idx - 1) %/% 4 != 5) { + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + } + all_plt_ls[[drug]] <- plt + } + plt <- patchwork::wrap_plots(all_plt_ls, guides = "collect", ncol = 4, nrow = 6) & + ggplot2::theme( + plot.title = ggplot2::element_text(hjust = 0.5), + panel.background = ggplot2::element_rect(fill = "white"), + panel.grid.major = ggplot2::element_line(color = "white") + ) + ggplot2::ggsave(plt, + filename = file.path(figures_dir, "ccle_rnaseq_predictions.pdf"), + width = fig_width * 1, height = fig_height * 2) + + pred10_ranks <- pred_results %>% + dplyr::filter(k == 10) %>% + dplyr::group_by(fi, y_task) %>% + dplyr::summarise(r2 = mean(r2)) %>% + tidyr::pivot_wider(names_from = y_task, values_from = r2) %>% + dplyr::ungroup() %>% + dplyr::mutate(dplyr::across(`17-AAG`:`ZD-6474`, 
~rank(-.x))) + pred10_ranks_table <- apply(pred10_ranks, 1, FUN = function(x) table(x[-1])) %>% + tibble::as_tibble() %>% + tibble::rownames_to_column("Rank") %>% + setNames(c("Rank", pred10_ranks$fi)) %>% + dplyr::select(Rank, `MDI+`, TreeSHAP, MDI, `MDI-oob`, MDA) %>% + vthemes::pretty_kable(format = "latex") + top_genes_kable <- top_genes_table %>% + dplyr::filter(Rank <= 5) %>% + dplyr::select(-Drug) %>% + dplyr::rename(" " = "Rank") %>% + vthemes::pretty_kable() %>% + kableExtra::kable_styling(latex_options = "repeat_header") %>% + # kableExtra::collapse_rows(columns = 1, valign = "middle") + kableExtra::pack_rows( + index = rep(5, 24) %>% setNames(unique(top_genes_table$Drug)) + ) + vthemes::subchunkify( + top_genes_kable, i = sprintf("%s-table", fpath) + ) +} +``` + +```{r eval = TRUE, results = "asis"} +if (params$for_paper) { + manual_color_palette_full <- c("black", "orange", "#71beb7", "#218a1e", "#cc3399") + show_methods_full <- c("MDI+", "MDA", "MDI", "MDI-oob", "TreeSHAP") + method_labels_full <- c("MDI+ (ridge)", "MDA", "MDI", "MDI-oob", "TreeSHAP") + keep_methods_list <- list( + all = show_methods_full, + no_mdioob = c("MDI+", "MDA", "MDI", "TreeSHAP"), + zoom = c("MDI+", "MDI", "TreeSHAP") + ) + + fpath <- "mdi_plus.real_data_case_study.ccle_rnaseq_regression-" + fname <- file.path( + results_dir, + fpath, + sprintf("seed%s", seed), + "results.csv" + ) + X <- data.table::fread("../../data/X_ccle_rnaseq_cleaned_filtered5000.csv") + results <- data.table::fread(fname) + + for (id in names(keep_methods_list)) { + keep_methods <- keep_methods_list[[id]] + manual_color_palette <- manual_color_palette_full[show_methods_full %in% keep_methods] + show_methods <- show_methods_full[show_methods_full %in% keep_methods] + method_labels <- method_labels_full[show_methods_full %in% keep_methods] + + plt_ls <- plot_top_stability( + results, + group_id = "y_task", + top_r = 10, + show_max_features = 5, + varnames = colnames(X), + base_method = "MDI+", + return_df = FALSE, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels + ) + all_plt_ls <- list() + for (idx in 1:length(unique(results$y_task))) { + drug <- sort(unique(results$y_task))[idx] + plt <- plt_ls[[drug]] + if ((idx - 1) %% 4 != 0) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if ((idx - 1) %/% 4 != 5) { + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + } + all_plt_ls[[drug]] <- plt + } + plt <- patchwork::wrap_plots(all_plt_ls, guides = "collect", ncol = 4, nrow = 6) & + ggplot2::theme( + plot.title = ggplot2::element_text(hjust = 0.5), + panel.background = ggplot2::element_rect(fill = "white"), + panel.grid.major = ggplot2::element_line(color = "grey95", size = ggplot2::rel(0.25)) + ) + ggplot2::ggsave(plt, + filename = file.path(figures_dir, sprintf("ccle_rnaseq_stability_select_features_%s.pdf", id)), + width = fig_width * 1, height = fig_height * 2) + } +} +``` + + +# TCGA BRCA Case Study {.tabset .tabset-pills .tabset-square} + +```{r eval = TRUE, results = "asis"} +manual_color_palette <- c("black", "#9B5DFF", "orange", "#71beb7", "#cc3399") +show_methods <- c("MDI+_ridge", "MDI+_logistic_logloss", "MDA", "MDI", "TreeSHAP") +method_labels <- c("MDI+ (ridge)", "MDI+ (logistic)", "MDA", "MDI", "TreeSHAP") +alpha_values <- c(1, 1, rep(0.4, length(method_labels) - 2)) + +fpath <- "mdi_plus.real_data_case_study.tcga_brca_classification-" +fname <- file.path( + results_dir, + fpath, + sprintf("seed%s", seed), + 
"results.csv" +) +X <- data.table::fread("../../X_tcga_cleaned.csv") +results <- data.table::fread(fname) +out <- plot_top_stability( + results, + group_id = NULL, + top_r = 10, + show_max_features = 20, + varnames = colnames(X), + base_method = "MDI+ (ridge)", + # descending_methods = "MDI+ (logistic)", + return_df = TRUE, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels +) + +ranking_df <- out$rankings +stability_df <- out$stability +plt_ls <- out$plot_ls + +cat("\n\n## Non-zero Stability Scores Per Method\n\n") +vthemes::subchunkify( + plt_ls[[1]], i = sprintf("%s-%s", fpath, 1), + fig_height = fig_height * round(length(unique(results$fi)) / 2), fig_width = fig_width +) + +cat("\n\n## Stability of Top Features\n\n") +vthemes::subchunkify( + plot_fun(plt_ls[[2]]), i = sprintf("%s-%s", fpath, 2), + fig_height = fig_height, fig_width = fig_width +) + +# get top 25 genes +top_genes <- ranking_df %>% + dplyr::group_by(fi, var) %>% + dplyr::summarise(mean_rank = mean(rank)) %>% + dplyr::ungroup() %>% + dplyr::arrange(mean_rank) %>% + dplyr::group_by(fi) %>% + dplyr::mutate(rank = 1:dplyr::n()) %>% + dplyr::ungroup() + +top_genes_table <- top_genes %>% + dplyr::filter(rank <= 25) %>% + dplyr::left_join( + y = data.frame(var = 0:(ncol(X) - 1), + gene = colnames(X)), + by = "var" + ) %>% + dplyr::mutate( + value = sprintf("%s (%s)", gene, round(mean_rank, 2)) + ) %>% + tidyr::pivot_wider( + id_cols = rank, + names_from = fi, + values_from = value + ) %>% + dplyr::rename( + "Rank" = "rank" + ) + +# write.csv( +# top_genes_table, file.path(tables_dir, "tcga_top_genes_table.csv"), +# row.names = FALSE, quote = FALSE +# ) + +stable_genes_kable <- stability_df %>% + dplyr::group_by(fi) %>% + dplyr::summarise(stability_1 = sum(stability_score == 1)) %>% + tibble::as_tibble() + +cat("\n\n## Table of Top Genes \n\n") +if (params$for_paper) { + top_genes_kable <- top_genes_table %>% + vthemes::pretty_kable(format = "latex")#, longtable = TRUE) %>% + # kableExtra::kable_styling(latex_options = "repeat_header") +} else { + top_genes_kable <- top_genes_table %>% + vthemes::pretty_kable() %>% + kableExtra::kable_styling(latex_options = "repeat_header") + vthemes::subchunkify( + top_genes_kable, i = sprintf("%s-table", fpath) + ) +} + +cat("\n\n## Prediction using Top Features\n\n") +pred_fname <- file.path( + results_dir, + fpath, + sprintf("seed%s", seed), + "pred_results.csv" +) +pred_results <- data.table::fread(pred_fname) %>% + tibble::as_tibble() %>% + dplyr::mutate(method = fi) %>% + tidyr::pivot_wider(names_from = metric, values_from = metric_value) +plt <- pred_results %>% + dplyr::filter(k <= 25) %>% + plot_metrics( + metric = c("rocauc", "accuracy"), + x_str = "k", + facet_str = NULL, + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + custom_theme = custom_theme, + inside_legend = FALSE + ) + + ggplot2::labs(x = "Number of Top Features (k)", y = "Mean Prediction Performance", + color = "Method", alpha = "Method") +if (params$for_paper) { + ggplot2::ggsave(plt, filename = file.path(figures_dir, "tcga_brca_prediction.pdf"), + width = fig_width * 1.3, height = fig_height * .85) +} else { + vthemes::subchunkify( + plot_fun(plt), i = sprintf("%s-pred", fpath), + fig_height = fig_height * .8, fig_width = fig_width * 1.5 + ) +} +``` + +```{r eval = TRUE, results = "asis"} 
+if (params$for_paper) { + manual_color_palette <- c("black", "#9B5DFF", "orange", "#71beb7", "#218a1e", "#cc3399") + show_methods <- c("MDI+_ridge", "MDI+_logistic_logloss", "MDA", "MDI", "MDI-oob", "TreeSHAP") + method_labels <- c("MDI+ (ridge)", "MDI+ (logistic)", "MDA", "MDI", "MDI-oob", "TreeSHAP") + + # ccle rnaseq + fpath <- "mdi_plus.real_data_case_study.ccle_rnaseq_regression-" + fname <- file.path( + results_dir, + fpath, + sprintf("seed%s", seed), + "results.csv" + ) + X_ccle <- data.table::fread("../../data/X_ccle_rnaseq_cleaned_filtered5000.csv") + results_ccle <- data.table::fread(fname) %>% + dplyr::mutate( + fi = ifelse(fi == "MDI+", "MDI+_ridge", fi) + ) + + # tcga + fpath <- "mdi_plus.real_data_case_study.tcga_brca_classification-" + fname <- file.path( + results_dir, + fpath, + sprintf("seed%s", seed), + "results.csv" + ) + X_tcga <- data.table::fread("../../data/X_tcga_cleaned.csv") + results_tcga <- data.table::fread(fname) + + # join results + keep_cols <- c("rep", "y_task", "model", "fi", "index", "var", "importance") + results <- dplyr::bind_rows( + results_ccle %>% dplyr::select(tidyselect::all_of(keep_cols)), + results_tcga %>% dplyr::mutate(y_task = "TCGA-BRCA") %>% dplyr::select(tidyselect::all_of(keep_cols)) + ) %>% + tibble::as_tibble() + plt_ls <- plot_top_stability( + results, + group_id = "y_task", + top_r = 10, + show_max_features = 5, + base_method = "MDI+ (ridge)", + return_df = FALSE, + manual_color_palette = rev(manual_color_palette), + show_methods = rev(show_methods), + method_labels = rev(method_labels) + ) + + plt <- plt_ls$Summary + + ggplot2::labs( + y = "Drug", + x = "Number of Distinct Features in Top 10 Across 32 Training-Test Splits" + ) + + ggplot2::scale_x_continuous( + n.breaks = 6 + ) + ggplot2::ggsave(plt, filename = file.path(figures_dir, "casestudy_stability_top10.pdf"), + width = fig_width * 1.5, height = fig_height * 1.2) + + keep_methods <- c("MDI+_ridge", "MDI+_logistic_logloss", "MDA", "MDI", "TreeSHAP") + manual_color_palette_small <- manual_color_palette[show_methods %in% keep_methods] + show_methods_small <- show_methods[show_methods %in% keep_methods] + method_labels_small <- method_labels[show_methods %in% keep_methods] + + plt_tcga <- plot_top_stability( + results_tcga, + top_r = 10, + show_max_features = 5, + base_method = "MDI+ (ridge)", + return_df = FALSE, + manual_color_palette = manual_color_palette_small, + show_methods = show_methods_small, + method_labels = method_labels_small + )$`Stability of Top Features` + + ggplot2::labs(title = "TCGA-BRCA") + + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + + plt_ls <- plot_top_stability( + results_ccle, + group_id = "y_task", + top_r = 10, + show_max_features = 5, + # varnames = colnames(X), + base_method = "MDI+ (ridge)", + return_df = FALSE, + manual_color_palette = manual_color_palette_small[show_methods_small != "MDI+_logistic_logloss"], + show_methods = show_methods_small[show_methods_small != "MDI+_logistic_logloss"], + method_labels = method_labels_small[show_methods_small != "MDI+_logistic_logloss"] + ) + + # keep_drugs <- c("Panobinostat", "Lapatinib", "L-685458") + keep_drugs <- c("PD-0325901", "Panobinostat", "L-685458") + small_plt_ls <- list() + for (drug in keep_drugs) { + plt <- plt_ls[[drug]] + if (drug != keep_drugs[1]) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + small_plt_ls[[drug]] <- plt + + ggplot2::guides(fill = "none") + } + small_plt_ls[["TCGA-BRCA"]] <- plt_tcga + # small_plt_ls[[1]] + small_plt_ls[[2]] + 
small_plt_ls[[3]] + small_plt_ls[[4]] + + # patchwork::plot_layout(nrow = 1, ncol = 4, guides = "collect") + plt <- patchwork::wrap_plots(small_plt_ls, nrow = 1, ncol = 4, guides = "collect") & + ggplot2::theme( + plot.title = ggplot2::element_text(hjust = 0.5), + panel.background = ggplot2::element_rect(fill = "white"), + panel.grid.major = ggplot2::element_line(color = "grey95", size = ggplot2::rel(0.25)) + ) + ggplot2::ggsave(plt, filename = file.path(figures_dir, "casestudy_stability_select_features.pdf"), + width = fig_width * 1, height = fig_height * 0.5) + + plt <- small_plt_ls[[1]] + small_plt_ls[[2]] + small_plt_ls[[3]] + patchwork::plot_spacer() + small_plt_ls[[4]] + + patchwork::plot_layout(nrow = 1, widths = c(1, 1, 1, 0.1, 1), guides = "collect") + # plt + # grid::grid.draw( + # grid::linesGrob(x = grid::unit(c(0.68, 0.68), "npc"), + # y = grid::unit(c(0.06, 0.98), "npc")) + # ) + # # ggplot2::ggsave( + # # file.path(figures_dir, "casestudy_stability_select_features.pdf"), + # # units = "in", width = 14, height = 4 + # # ) +} + +``` + + +# Misspecified Models {.tabset .tabset-vmodern} + +```{r eval = TRUE, results = "asis"} +keep_methods <- color_df$name %in% c("MDI+_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF") +manual_color_palette <- color_df$color[keep_methods] +show_methods <- color_df$name[keep_methods] +method_labels <- color_df$label[keep_methods] +alpha_values <- c(1, rep(0.4, length(method_labels) - 1)) +legend_position <- c(0.73, 0.35) + +vary_param_name <- "heritability_sample_row_n" +y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r") +x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing") + +remove_x_axis_models <- c("enhancer", "ccle_rnaseq", "juvenile") +keep_legend_x_models <- c("enhancer") +keep_legend_y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r") + +for (y_model in y_models) { + cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model)) + for (x_model in x_models) { + cat(sprintf("\n\n### %s \n\n", x_model)) + plt_ls <- list() + sim_name <- sprintf("%s_%s_dgp_omitted_vars", x_model, y_model) + sim_title <- dplyr::case_when( + x_model == "ccle_rnaseq" ~ "CCLE", + x_model == "splicing" ~ "Splicing", + x_model == "enhancer" ~ "Enhancer", + x_model == "juvenile" ~ "Juvenile" + ) + metric_name <- dplyr::case_when( + metric == "rocauc" ~ "AUROC", + metric == "prauc" ~ "PRAUC" + ) + fname <- file.path(results_dir, paste0("mdi_plus.regression_sims.", sim_name), + paste0("varying_", vary_param_name), + paste0("seed", seed), "results") + if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) { + results <- readRDS(sprintf("%s.rds", fname)) + if (length(setdiff(show_methods, unique(results$method))) > 0) { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + } else { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + if (params$for_paper) { + for (h in heritabilities) { + plt <- results %>% + dplyr::filter(heritability == !!h) %>% + plot_metrics( + metric = metric, + x_str = "sample_row_n", + facet_str = NULL, + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + legend_position = legend_position, + custom_theme = custom_theme, + 
inside_legend = TRUE + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method" + ) + if (h != heritabilities[1]) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if (x_model %in% remove_x_axis_models) { + height <- 2.72 + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + } else { + height <- 3 + } + if (!((h == heritabilities[length(heritabilities)]) & + (x_model %in% keep_legend_x_models) & + (y_model %in% keep_legend_y_models))) { + plt <- plt + ggplot2::guides(color = "none", alpha = "none") + } + plt_ls[[as.character(h)]] <- plt + } + plt <- patchwork::wrap_plots(plt_ls, nrow = 1) + ggplot2::ggsave( + file.path(figures_dir, "misspecified_regression_sims", + sprintf("misspecified_regression_sims_%s_%s_%s_errbars.pdf", + y_model, x_model, metric)), + plot = plt, units = "in", width = 14, height = height + ) + } else { + plt <- results %>% + plot_metrics( + metric = metric, + x_str = "sample_row_n", + facet_str = "heritability_name", + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + custom_theme = custom_theme + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method", + title = sprintf("%s", sim_title) + ) + vthemes::subchunkify(plot_fun(plt), i = chunk_idx, + fig_height = fig_height, fig_width = fig_width * 1.3) + chunk_idx <- chunk_idx + 1 + } + } +} + +``` + + +# Varying Sparsity {.tabset .tabset-vmodern} + +```{r eval = TRUE, results = "asis"} +keep_methods <- color_df$name %in% c("MDI+_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF") +manual_color_palette <- color_df$color[keep_methods] +show_methods <- color_df$name[keep_methods] +method_labels <- color_df$label[keep_methods] +alpha_values <- c(1, rep(0.4, length(method_labels) - 1)) +legend_position <- c(0.73, 0.7) + +y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r") +x_models <- c("juvenile", "splicing") + +remove_x_axis_models <- c("splicing") +keep_legend_x_models <- c("splicing") +keep_legend_y_models <- c("linear", "linear_lss_3m_2r") + +for (y_model in y_models) { + cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model)) + for (x_model in x_models) { + cat(sprintf("\n\n### %s \n\n", x_model)) + plt_ls <- list() + sim_name <- sprintf("%s_%s_dgp", x_model, y_model) + sim_title <- dplyr::case_when( + x_model == "ccle_rnaseq" ~ "CCLE", + x_model == "splicing" ~ "Splicing", + x_model == "enhancer" ~ "Enhancer", + x_model == "juvenile" ~ "Juvenile" + ) + metric_name <- dplyr::case_when( + metric == "rocauc" ~ "AUROC", + metric == "prauc" ~ "PRAUC" + ) + x_str <- ifelse(y_model == "linear", "s", "m") + vary_param_name <- sprintf("heritability_%s", x_str) + fname <- file.path(results_dir, paste0("mdi_plus.other_regression_sims.varying_sparsity.", sim_name), + paste0("varying_", vary_param_name), + paste0("seed", seed), "results") + if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) { + results <- readRDS(sprintf("%s.rds", fname)) + if (length(setdiff(show_methods, unique(results$method))) > 0) { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + } else { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + 
} + if (params$for_paper) { + for (h in heritabilities) { + plt <- results %>% + dplyr::filter(heritability == !!h) %>% + plot_metrics( + metric = metric, + x_str = x_str, + facet_str = NULL, + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + legend_position = legend_position, + custom_theme = custom_theme, + inside_legend = TRUE + ) + + ggplot2::labs( + x = toupper(x_str), y = metric_name, + color = "Method", alpha = "Method" + ) + if (h != heritabilities[1]) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if (x_model %in% remove_x_axis_models) { + height <- 2.72 + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + } else { + height <- 3 + } + if (!((h == heritabilities[length(heritabilities)]) & + (x_model %in% keep_legend_x_models) & + (y_model %in% keep_legend_y_models))) { + plt <- plt + ggplot2::guides(color = "none", alpha = "none") + } + plt_ls[[as.character(h)]] <- plt + } + plt <- patchwork::wrap_plots(plt_ls, nrow = 1) + ggplot2::ggsave( + file.path(figures_dir, "varying_sparsity", + sprintf("regression_sims_sparsity_%s_%s_%s_errbars.pdf", + y_model, x_model, metric)), + plot = plt, units = "in", width = 14, height = 3 + ) + } else { + plt <- results %>% + plot_metrics( + metric = metric, + x_str = x_str, + facet_str = "heritability_name", + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + custom_theme = custom_theme + ) + + ggplot2::labs( + x = toupper(x_str), y = metric_name, + color = "Method", alpha = "Method", + title = sprintf("%s", sim_title) + ) + vthemes::subchunkify(plot_fun(plt), i = chunk_idx, + fig_height = fig_height, fig_width = fig_width * 1.3) + chunk_idx <- chunk_idx + 1 + } + } +} + +``` + + +# Varying \# Features {.tabset .tabset-vmodern} + +```{r eval = TRUE, results = "asis"} +keep_methods <- color_df$name %in% c("MDI+_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF") +manual_color_palette <- color_df$color[keep_methods] +show_methods <- color_df$name[keep_methods] +method_labels <- color_df$label[keep_methods] +alpha_values <- c(1, rep(0.4, length(method_labels) - 1)) +legend_position <- c(0.73, 0.4) + +vary_param_name <- "heritability_sample_col_n" +y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r") +x_models <- c("ccle_rnaseq") + +remove_x_axis_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r") +keep_legend_x_models <- c("ccle_rnaseq") +keep_legend_y_models <- c("linear") + +for (y_model in y_models) { + cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model)) + for (x_model in x_models) { + cat(sprintf("\n\n### %s \n\n", x_model)) + plt_ls <- list() + sim_name <- sprintf("%s_%s_dgp", x_model, y_model) + sim_title <- dplyr::case_when( + x_model == "ccle_rnaseq" ~ "CCLE", + x_model == "splicing" ~ "Splicing", + x_model == "enhancer" ~ "Enhancer", + x_model == "juvenile" ~ "Juvenile" + ) + metric_name <- dplyr::case_when( + metric == "rocauc" ~ "AUROC", + metric == "prauc" ~ "PRAUC" + ) + fname <- file.path(results_dir, paste0("mdi_plus.other_regression_sims.varying_p.", sim_name), + paste0("varying_", vary_param_name), + paste0("seed", seed), "results") + if (file.exists(sprintf("%s.rds", fname)) & 
params$use_cached) { + results <- readRDS(sprintf("%s.rds", fname)) + if (length(setdiff(show_methods, unique(results$method))) > 0) { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + } else { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + if (params$for_paper) { + for (h in heritabilities) { + plt <- results %>% + dplyr::filter(heritability == !!h) %>% + plot_metrics( + metric = metric, + x_str = "sample_col_n", + facet_str = NULL, + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + legend_position = legend_position, + custom_theme = custom_theme, + inside_legend = TRUE + ) + + ggplot2::labs( + x = "Number of Features", y = metric_name, + color = "Method", alpha = "Method" + ) + if (h != heritabilities[1]) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if (y_model %in% remove_x_axis_models) { + height <- 2.72 + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + } else { + height <- 3 + } + if (!((h == heritabilities[length(heritabilities)]) & + (x_model %in% keep_legend_x_models) & + (y_model %in% keep_legend_y_models))) { + plt <- plt + ggplot2::guides(color = "none", alpha = "none") + } + plt_ls[[as.character(h)]] <- plt + } + plt <- patchwork::wrap_plots(plt_ls, nrow = 1) + ggplot2::ggsave( + file.path(figures_dir, "varying_p", + sprintf("regression_sims_vary_p_%s_%s_%s_errbars.pdf", + y_model, x_model, metric)), + plot = plt, units = "in", width = 14, height = height + ) + } else { + plt <- results %>% + plot_metrics( + metric = metric, + x_str = "sample_col_n", + facet_str = "heritability_name", + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + custom_theme = custom_theme + ) + + ggplot2::labs( + x = "Number of Features", y = metric_name, + color = "Method", alpha = "Method", + title = sprintf("%s", sim_title) + ) + vthemes::subchunkify(plot_fun(plt), i = chunk_idx, + fig_height = fig_height, fig_width = fig_width * 1.3) + chunk_idx <- chunk_idx + 1 + } + } +} + +``` + + +# Prediction Results {.tabset .tabset-vmodern} + +```{r eval = TRUE, results = "asis"} +fpaths <- c( + "mdi_plus.prediction_sims.ccle_rnaseq_regression-", + "mdi_plus.prediction_sims.enhancer_classification-", + "mdi_plus.prediction_sims.splicing_classification-", + "mdi_plus.prediction_sims.juvenile_classification-", + "mdi_plus.prediction_sims.tcga_brca_classification-" +) + +keep_models <- c("RF", "RF-ridge", "RF-lasso", "RF-logistic") +prediction_metrics <- c( + "r2", "explained_variance", "mean_squared_error", "mean_absolute_error", + "rocauc", "prauc", "accuracy", "f1", "recall", "precision", "avg_precision", "logloss" +) + +results_ls <- list() +for (fpath in fpaths) { + fname <- file.path( + results_dir, + fpath, + sprintf("seed%s", seed), + "results.csv" + ) + results <- data.table::fread(fname) %>% + reformat_results(prediction = TRUE) + + if (!("y_task" %in% colnames(results))) { + results <- results %>% + dplyr::mutate(y_task = "Results") + } + + plt_df <- results %>% + tidyr::pivot_longer( + cols = tidyselect::any_of(prediction_metrics), + 
names_to = "metric", values_to = "value" + ) %>% + dplyr::group_by(y_task, model, metric) %>% + dplyr::summarise( + mean = mean(value), + sd = sd(value), + n = dplyr::n() + ) %>% + dplyr::mutate( + se = sd / sqrt(n) + ) + + if (fpath == "mdi_plus.prediction_sims.ccle_rnaseq_regression") { + tab <- plt_df %>% + dplyr::mutate( + value = sprintf("%.3f (%.3f)", mean, se) + ) %>% + tidyr::pivot_wider(id_cols = model, names_from = "metric", values_from = "value") %>% + dplyr::arrange(dplyr::desc(accuracy)) %>% + dplyr::select(model, accuracy, f1, rocauc, prauc, precision, recall) %>% + vthemes::pretty_kable(format = "latex") + } + + if (fpath != "mdi_plus.prediction_sims.ccle_rnaseq_regression-") { + results_ls[[fpath]] <- plt_df %>% + dplyr::filter(model %in% c("RF", "RF-logistic")) %>% + dplyr::select(model, metric, mean) %>% + tidyr::pivot_wider(id_cols = metric, names_from = "model", values_from = "mean") %>% + dplyr::mutate(diff = `RF-logistic` - RF, + percent_diff = (`RF-logistic` - RF) / abs(RF) * 100) + } else { + results_ls[[fpath]] <- plt_df %>% + dplyr::filter(model %in% c("RF", "RF-ridge")) %>% + dplyr::select(y_task, model, metric, mean) %>% + tidyr::pivot_wider(id_cols = c(y_task, metric), names_from = "model", values_from = "mean") %>% + dplyr::mutate(diff = `RF-ridge` - RF, + percent_diff = (`RF-ridge` - RF) / abs(RF) * 100) + } +} + +plt_reg <- results_ls$`mdi_plus.prediction_sims.ccle_rnaseq_regression-` %>% + dplyr::filter(metric == "r2") %>% + dplyr::filter(RF > 0.1) %>% + dplyr::mutate( + metric = ifelse(metric == "r2", "R-squared", metric) + ) %>% + ggplot2::ggplot() + + ggplot2::aes(x = RF, y = percent_diff, label = y_task) + + ggplot2::labs(x = "RF Test R-squared", y = "% Change Using RF+ (ridge)") + + ggplot2::geom_point(size = 3) + + ggplot2::geom_hline(yintercept = 0, linetype = "dashed") + + ggrepel::geom_label_repel(fill = "white") + + ggplot2::facet_wrap(~ metric, scales = "free") + + custom_theme + +plt_reg_all <- results_ls$`mdi_plus.prediction_sims.ccle_rnaseq_regression-` %>% + dplyr::filter(metric == "r2") %>% + dplyr::mutate( + metric = ifelse(metric == "r2", "R-squared", metric) + ) %>% + ggplot2::ggplot() + + ggplot2::aes(x = RF, y = percent_diff, label = y_task) + + ggplot2::labs(x = "RF Test R-squared", y = "% Change Using RF+ (ridge)") + + ggplot2::geom_point(size = 3) + + ggplot2::geom_hline(yintercept = 0, linetype = "dashed") + + ggrepel::geom_label_repel(fill = "white") + + ggplot2::facet_wrap(~ metric, scales = "free") + + custom_theme + +if (params$for_paper) { + ggplot2::ggsave( + plot = plt_reg_all, + file.path(figures_dir, "prediction_results_appendix.pdf"), + units = "in", width = 8, height = fig_height * 0.75 + ) +} + +plt_ls <- list(reg = plt_reg) +for (m in c("f1", "prauc")) { + plt_df <- dplyr::bind_rows(results_ls, .id = "dataset") %>% + dplyr::filter(metric == m) %>% + dplyr::mutate( + dataset = stringr::str_remove(dataset, "^mdi_plus\\.prediction_sims\\.") %>% + stringr::str_remove("_classification-$") %>% + stringr::str_to_title() %>% + stringr::str_replace("Tcga_brca", "TCGA BRCA"), + diff = ifelse(metric == "logloss", -diff, diff), + percent_diff = ifelse(metric == "logloss", -percent_diff, percent_diff), + metric = forcats::fct_recode(metric, + "Accuracy" = "accuracy", + "F1" = "f1", + "Negative Log-loss" = "logloss", + "AUPRC" = "prauc", + "AUROC" = "rocauc") + ) + plt_ls[[m]] <- plt_df %>% + ggplot2::ggplot() + + ggplot2::aes(x = RF, y = percent_diff, label = dataset) + + ggplot2::labs(x = sprintf("RF Test %s", plt_df$metric[1]), + 
y = "% Change Using RF+ (logistic)") + + ggplot2::geom_point(size = 3) + + ggplot2::geom_hline(yintercept = 0, linetype = "dashed") + + ggrepel::geom_label_repel(point.padding = 0.5, fill = "white") + + ggplot2::facet_wrap(~ metric, scales = "free", nrow = 1) + + custom_theme +} + +plt_ls[[3]] <- plt_ls[[3]] + + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + +plt <- plt_ls[[1]] + patchwork::plot_spacer() + plt_ls[[2]] + plt_ls[[3]] + + patchwork::plot_layout(nrow = 1, widths = c(1, 0.3, 1, 1)) + +# plt +# grid::grid.draw( +# grid::linesGrob(x = grid::unit(c(0.36, 0.36), "npc"), +# y = grid::unit(c(0.06, 0.98), "npc")) +# ) +# # ggplot2::ggsave( +# # file.path(figures_dir, "prediction_results_main.pdf"), +# # units = "in", width = 14, height = fig_height * 0.75 +# # ) + +vthemes::subchunkify(plt, i = chunk_idx, + fig_height = fig_height, fig_width = fig_width * 1.3) +chunk_idx <- chunk_idx + 1 +``` + + +# MDI+ Modeling Choices {.tabset .tabset-vmodern} + +```{r eval = TRUE, results = "asis"} +manual_color_palette <- c(manual_color_palette_choices[method_labels_choices == "MDI"], + manual_color_palette_choices[method_labels_choices == "MDI-oob"], + "#AF4D98", "#FFD92F", "#FC8D62", "black") +show_methods <- c("MDI_RF", "MDI-oob_RF", "MDI+_raw_RF", "MDI+_loo_RF", "MDI+_ols_raw_loo_RF", "MDI+_ridge_raw_loo_RF") +method_labels <- c("MDI", "MDI-oob", "MDI+ (raw only)", "MDI+ (loo only)", "MDI+ (raw+loo only)", "MDI+ (ridge+raw+loo)") +alpha_values <- rep(1, length(method_labels)) +legend_position <- c(0.63, 0.4) + +vary_param_name <- "heritability_sample_row_n" +y_models <- c("linear", "hier_poly_3m_2r") +x_models <- c("enhancer", "ccle_rnaseq")#, "splicing") +min_samples_per_leafs <- c(5, 1) + +remove_x_axis_models <- c("enhancer") +keep_legend_x_models <- c("enhancer") +keep_legend_y_models <- c("linear") + +for (y_model in y_models) { + cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model)) + for (x_model in x_models) { + cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model)) + plt_ls <- list() + for (min_samples_per_leaf in min_samples_per_leafs) { + cat(sprintf("\n\n#### min_samples_per_leaf = %s \n\n", min_samples_per_leaf)) + sim_name <- sprintf("%s_%s_dgp", x_model, y_model) + sim_title <- dplyr::case_when( + x_model == "ccle_rnaseq" ~ "CCLE", + x_model == "splicing" ~ "Splicing", + x_model == "enhancer" ~ "Enhancer", + x_model == "juvenile" ~ "Juvenile" + ) + metric_name <- dplyr::case_when( + metric == "rocauc" ~ "AUROC", + metric == "prauc" ~ "PRAUC" + ) + fname <- file.path(results_dir, + sprintf("mdi_plus.other_regression_sims.modeling_choices_min_samples%s.%s", min_samples_per_leaf, sim_name), + paste0("varying_", vary_param_name), + paste0("seed", seed), "results") + if (!file.exists(sprintf("%s.csv", fname))) { + next + } + if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) { + results <- readRDS(sprintf("%s.rds", fname)) + if (length(setdiff(show_methods, unique(results$method))) > 0) { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + } else { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + if (params$for_paper) { + for (h in heritabilities) { + plt <- results %>% + dplyr::filter(heritability == !!h) %>% + plot_metrics( + metric = metric, + x_str = "sample_row_n", + facet_str = NULL, + point_size = point_size, + line_size = line_size, + 
errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + legend_position = legend_position, + custom_theme = custom_theme, + inside_legend = TRUE + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method" + ) + + ggplot2::theme( + legend.text = ggplot2::element_text(size = 12) + ) + if (h != heritabilities[1]) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if (x_model %in% remove_x_axis_models) { + height <- 2.72 + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + } else { + height <- 3 + } + if (!((h == heritabilities[length(heritabilities)]) & + (x_model %in% keep_legend_x_models) & + (y_model %in% keep_legend_y_models))) { + plt <- plt + ggplot2::guides(color = "none", alpha = "none") + } + plt_ls[[as.character(h)]] <- plt + } + plt <- patchwork::wrap_plots(plt_ls, nrow = 1) + ggplot2::ggsave( + file.path(figures_dir, "modeling_choices", + sprintf("regression_sims_choices_min_samples%s_%s_%s_%s_errbars.pdf", + min_samples_per_leaf, y_model, x_model, metric)), + plot = plt, units = "in", width = 14, height = height + ) + } else { + plt <- results %>% + plot_metrics( + metric = metric, + x_str = "sample_row_n", + facet_str = "heritability_name", + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + custom_theme = custom_theme + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method", + title = sprintf("%s", sim_title) + ) + vthemes::subchunkify(plot_fun(plt), i = chunk_idx, + fig_height = fig_height, fig_width = fig_width * 1.3) + chunk_idx <- chunk_idx + 1 + } + } + } +} + +``` + + +# MDI+ GLM/Metric Choices {.tabset .tabset-vmodern} + +## Held-out Test Prediction Scores {.tabset .tabset-pills .tabset-square} + +```{r eval = TRUE, results = "asis"} +manual_color_palette <- c("#A5AE9E", "black", "brown") +show_methods <- c("RF", "RF-ridge", "RF-lasso") +method_labels <- c("RF", "RF+ridge", "RF+lasso") +alpha_values <- rep(1, length(method_labels)) +legend_position <- c(0.63, 0.4) + +vary_param_name <- "heritability_sample_row_n" +y_models <- c("linear", "hier_poly_3m_2r") +x_models <- c("enhancer", "ccle_rnaseq")#, "splicing") +prediction_metrics <- c( + "r2" #, "mean_absolute_error" + # "rocauc", "prauc", "accuracy", "f1", "recall", "precision", "avg_precision", "logloss" +) + +remove_x_axis_models <- c("enhancer") +keep_legend_x_models <- c("enhancer") +keep_legend_y_models <- c("linear") + +for (y_model in y_models) { + cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-square}\n\n", y_model)) + for (x_model in x_models) { + cat(sprintf("\n\n#### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model)) + plt_ls <- list() + sim_name <- sprintf("%s_%s_dgp", x_model, y_model) + sim_title <- dplyr::case_when( + x_model == "ccle_rnaseq" ~ "CCLE", + x_model == "splicing" ~ "Splicing", + x_model == "enhancer" ~ "Enhancer", + x_model == "juvenile" ~ "Juvenile" + ) + fname <- file.path(results_dir, + sprintf("mdi_plus.glm_metric_choices_sims.regression_prediction_sims.%s", sim_name), + paste0("varying_", vary_param_name), + paste0("seed", seed), "results") + if (!file.exists(sprintf("%s.csv", fname))) { + next + } + if (file.exists(sprintf("%s.rds", fname)) & 
params$use_cached) { + results <- readRDS(sprintf("%s.rds", fname)) + if (length(setdiff(show_methods, unique(results$method))) > 0) { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results(prediction = TRUE) + saveRDS(results, sprintf("%s.rds", fname)) + } + } else { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results(prediction = TRUE) + saveRDS(results, sprintf("%s.rds", fname)) + } + for (m in prediction_metrics) { + metric_name <- dplyr::case_when( + m == "r2" ~ "R-squared", + m == "mae" ~ "Mean Absolute Error" + ) + if (params$for_paper) { + for (h in heritabilities) { + plt <- results %>% + dplyr::filter(heritability == !!h) %>% + plot_metrics( + metric = m, + x_str = "sample_row_n", + facet_str = NULL, + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + legend_position = legend_position, + custom_theme = custom_theme, + inside_legend = TRUE + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method" + ) + + ggplot2::theme( + legend.text = ggplot2::element_text(size = 12) + ) + if (h != heritabilities[1]) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if (x_model %in% remove_x_axis_models) { + height <- 2.72 + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + } else { + height <- 3 + } + if (!((h == heritabilities[length(heritabilities)]) & + (x_model %in% keep_legend_x_models) & + (y_model %in% keep_legend_y_models))) { + plt <- plt + ggplot2::guides(color = "none", alpha = "none") + } + plt_ls[[as.character(h)]] <- plt + } + plt <- patchwork::wrap_plots(plt_ls, nrow = 1) + ggplot2::ggsave( + file.path(figures_dir, "glm_metric_choices", + sprintf("regression_sims_glm_metric_choices_prediction_%s_%s_%s_errbars.pdf", + y_model, x_model, m)), + plot = plt, units = "in", width = 14, height = height + ) + } else { + plt <- results %>% + plot_metrics( + metric = m, + x_str = "sample_row_n", + facet_str = "heritability_name", + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + custom_theme = custom_theme + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method", + title = sprintf("%s", sim_title) + ) + vthemes::subchunkify(plot_fun(plt), i = chunk_idx, + fig_height = fig_height, fig_width = fig_width * 1.3) + chunk_idx <- chunk_idx + 1 + } + } + } +} + +``` + +## Stability Scores - Regression {.tabset .tabset-pills .tabset-square} + +```{r eval = TRUE, results = "asis"} +manual_color_palette <- c("black", "black", "brown", "brown", "gray", + manual_color_palette_choices[method_labels_choices == "MDI"], + manual_color_palette_choices[method_labels_choices == "MDI-oob"]) +show_methods <- c("MDI+_ridge_r2_RF", "MDI+_ridge_neg_mae_RF", "MDI+_lasso_r2_RF", "MDI+_lasso_neg_mae_RF", "MDI+_ensemble_RF", "MDI_RF", "MDI-oob_RF") +method_labels <- c("MDI+ (ridge, r-squared)", "MDI+ (ridge, MAE)", "MDI+ (lasso, r-squared)", "MDI+ (lasso, MAE)", "MDI+ (ensemble)", "MDI", "MDI-oob") +manual_linetype_palette <- c(1, 2, 1, 2, 1, 1, 1) +alpha_values <- c(rep(1, length(method_labels) - 2), rep(0.4, 2)) +legend_position <- c(0.63, 0.4) + +vary_param_name <- 
"heritability_sample_row_n" +# y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r") +# x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing") +y_models <- c("linear", "hier_poly_3m_2r") +x_models <- c("enhancer", "ccle_rnaseq") +# stability_metrics <- c("tauAP", "RBO") +stability_metrics <- c("RBO") + +remove_x_axis_models <- metric +keep_legend_x_models <- x_models +keep_legend_y_models <- y_models +keep_legend_metrics <- metric + +for (y_model in y_models) { + cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-square}\n\n", y_model)) + for (x_model in x_models) { + cat(sprintf("\n\n#### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model)) + sim_name <- sprintf("%s_%s_dgp", x_model, y_model) + sim_title <- dplyr::case_when( + x_model == "ccle_rnaseq" ~ "CCLE", + x_model == "splicing" ~ "Splicing", + x_model == "enhancer" ~ "Enhancer", + x_model == "juvenile" ~ "Juvenile" + ) + fname <- file.path(results_dir, + sprintf("mdi_plus.glm_metric_choices_sims.regression_sims.%s", sim_name), + paste0("varying_", vary_param_name), + paste0("seed", seed), "results") + if (!file.exists(sprintf("%s.csv", fname))) { + next + } + if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) { + results <- readRDS(sprintf("%s.rds", fname)) + if (length(setdiff(show_methods, unique(results$method))) > 0) { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + } else { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + plt_ls <- list() + for (m in c(metric, stability_metrics)) { + metric_name <- dplyr::case_when( + m == "rocauc" ~ "AUROC", + m == "prauc" ~ "PRAUC", + TRUE ~ m + ) + if (params$for_paper) { + for (h in heritabilities) { + plt <- results %>% + dplyr::filter(heritability == !!h) %>% + plot_metrics( + metric = m, + x_str = "sample_row_n", + facet_str = NULL, + linetype_str = "method", + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + legend_position = legend_position, + custom_theme = custom_theme, + inside_legend = FALSE + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method", linetype = "Method" + ) + + ggplot2::scale_linetype_manual( + values = manual_linetype_palette, labels = method_labels + ) + + ggplot2::theme( + legend.text = ggplot2::element_text(size = 12), + legend.key.width = ggplot2::unit(1.9, "cm") + ) + if (h != heritabilities[1]) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if (m %in% remove_x_axis_models) { + height <- 2.72 + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + } else { + height <- 3 + } + if (!((h == heritabilities[length(heritabilities)]) & + (x_model %in% keep_legend_x_models) & + (y_model %in% keep_legend_y_models) & + (metric %in% keep_legend_metrics))) { + plt <- plt + ggplot2::guides(color = "none", alpha = "none", linetype = "none") + } + plt_ls[[sprintf("%s_%s", as.character(h), m)]] <- plt + } + } else { + plt <- results %>% + plot_metrics( + metric = m, + x_str = "sample_row_n", + facet_str = "heritability_name", + linetype_str = "method", + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = 
show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + custom_theme = custom_theme + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method", linetype = "Method", + title = sprintf("%s", sim_title) + ) + + ggplot2::scale_linetype_manual( + values = manual_linetype_palette, labels = method_labels + ) + + ggplot2::theme( + legend.key.width = ggplot2::unit(1.9, "cm") + ) + vthemes::subchunkify(plot_fun(plt), i = chunk_idx, + fig_height = fig_height, fig_width = fig_width * 1.5) + chunk_idx <- chunk_idx + 1 + } + } + if (params$for_paper) { + nrows <- length(c(metric, stability_metrics)) + plt <- patchwork::wrap_plots(plt_ls, nrow = nrows, guides = "collect") + ggplot2::ggsave( + file.path(figures_dir, "glm_metric_choices", + sprintf("regression_sims_glm_metric_choices_stability_%s_%s_errbars.pdf", + y_model, x_model)), + plot = plt, units = "in", width = 14 * 1.2, height = height * nrows + ) + } + } +} + +``` + +## Stability Scores - Classification {.tabset .tabset-pills .tabset-square} + +```{r eval = TRUE, results = "asis"} +manual_color_palette <- c("#9B5DFF", "#9B5DFF", "black", + manual_color_palette_choices[method_labels_choices == "MDI"], + manual_color_palette_choices[method_labels_choices == "MDI-oob"]) +show_methods <- c("MDI+_logistic_ridge_logloss_RF", "MDI+_logistic_ridge_auroc_RF", "MDI+_ridge_r2_RF", "MDI_RF", "MDI-oob_RF") +method_labels <- c("MDI+ (logistic, log-loss)", "MDI+ (logistic, AUROC)", "MDI+ (ridge, r-squared)", "MDI", "MDI-oob") +manual_linetype_palette <- c(1, 2, 1, 1, 1) +alpha_values <- c(rep(1, length(method_labels) - 2), rep(0.4, 2)) +legend_position <- c(0.63, 0.4) + +vary_param_name <- "frac_label_corruption_sample_row_n" +# y_models <- c("logistic", "lss_3m_2r_logistic", "hier_poly_3m_2r_logistic", "linear_lss_3m_2r_logistic") +# x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing") +y_models <- c("logistic", "hier_poly_3m_2r_logistic") +x_models <- c("juvenile", "splicing") +# stability_metrics <- c("tauAP", "RBO") +stability_metrics <- "RBO" + +remove_x_axis_models <- metric +keep_legend_x_models <- x_models +keep_legend_y_models <- y_models +keep_legend_metrics <- metric + +for (y_model in y_models) { + cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-square}\n\n", y_model)) + for (x_model in x_models) { + cat(sprintf("\n\n#### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model)) + sim_name <- sprintf("%s_%s_dgp", x_model, y_model) + sim_title <- dplyr::case_when( + x_model == "ccle_rnaseq" ~ "CCLE", + x_model == "splicing" ~ "Splicing", + x_model == "enhancer" ~ "Enhancer", + x_model == "juvenile" ~ "Juvenile" + ) + fname <- file.path(results_dir, + sprintf("mdi_plus.glm_metric_choices_sims.classification_sims.%s", sim_name), + paste0("varying_", vary_param_name), + paste0("seed", seed), "results") + if (!file.exists(sprintf("%s.csv", fname))) { + next + } + if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) { + results <- readRDS(sprintf("%s.rds", fname)) + if (length(setdiff(show_methods, unique(results$method))) > 0) { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + } else { + results <- data.table::fread(sprintf("%s.csv", fname)) %>% + reformat_results() + saveRDS(results, sprintf("%s.rds", fname)) + } + plt_ls <- list() + for (m in c(metric, stability_metrics)) { + metric_name <- dplyr::case_when( + m == "rocauc" ~ "AUROC", + m == "prauc" ~ "PRAUC", + TRUE ~ m + ) + 
if (params$for_paper) { + for (h in frac_label_corruptions) { + plt <- results %>% + dplyr::filter(frac_label_corruption_name == !!h) %>% + plot_metrics( + metric = m, + x_str = "sample_row_n", + facet_str = NULL, + linetype_str = "method", + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + legend_position = legend_position, + custom_theme = custom_theme, + inside_legend = FALSE + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method", linetype = "Method" + ) + + ggplot2::scale_linetype_manual( + values = manual_linetype_palette, labels = method_labels + ) + + ggplot2::theme( + legend.text = ggplot2::element_text(size = 12), + legend.key.width = ggplot2::unit(1.9, "cm") + ) + if (h != frac_label_corruptions[1]) { + plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank()) + } + if (m %in% remove_x_axis_models) { + height <- 2.72 + plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank()) + } else { + height <- 3 + } + if (!((h == frac_label_corruptions[length(frac_label_corruptions)]) & + (x_model %in% keep_legend_x_models) & + (y_model %in% keep_legend_y_models) & + (metric %in% keep_legend_metrics))) { + plt <- plt + ggplot2::guides(color = "none", alpha = "none", linetype = "none") + } + plt_ls[[sprintf("%s_%s", as.character(h), m)]] <- plt + } + } else { + plt <- results %>% + plot_metrics( + metric = m, + x_str = "sample_row_n", + facet_str = "frac_label_corruption_name", + linetype_str = "method", + point_size = point_size, + line_size = line_size, + errbar_width = errbar_width, + manual_color_palette = manual_color_palette, + show_methods = show_methods, + method_labels = method_labels, + alpha_values = alpha_values, + custom_theme = custom_theme + ) + + ggplot2::labs( + x = "Sample Size", y = metric_name, + color = "Method", alpha = "Method", linetype = "Method", + title = sprintf("%s", sim_title) + ) + + ggplot2::scale_linetype_manual( + values = manual_linetype_palette, labels = method_labels + ) + + ggplot2::theme( + legend.key.width = ggplot2::unit(1.9, "cm") + ) + vthemes::subchunkify(plot_fun(plt), i = chunk_idx, + fig_height = fig_height, fig_width = fig_width * 1.5) + chunk_idx <- chunk_idx + 1 + } + } + if (params$for_paper) { + nrows <- length(c(metric, stability_metrics)) + plt <- patchwork::wrap_plots(plt_ls, nrow = nrows, guides = "collect") + ggplot2::ggsave( + file.path(figures_dir, "glm_metric_choices", + sprintf("classification_sims_glm_metric_choices_stability_%s_%s_errbars.pdf", + y_model, x_model)), + plot = plt, units = "in", width = 14 * 1.2, height = height * nrows + ) + } + } +} + +``` diff --git a/feature_importance/notebooks/mdi_plus/01_paper_figures.html b/feature_importance/notebooks/mdi_plus/01_paper_figures.html new file mode 100644 index 0000000..8561f39 --- /dev/null +++ b/feature_importance/notebooks/mdi_plus/01_paper_figures.html @@ -0,0 +1,5294 @@ + + + + + + + + + + + + MDI+ Simulation Results Summary + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
[Rendered HTML report "MDI+ Simulation Results Summary" (generated output of this Rmd). The remaining markup is the report's navigation sidebar, listing the sections Regression Simulations, Classification Simulations, Robust Simulations (Mean Shift = 10, 25), Correlation Bias Simulations (Main and Appendix Figures; PVE = 0.1, 0.4; Number of RF Splits), Entropy Bias Simulations (Average Rank; Number of RF Splits), and the CCLE Case Study (Summary; Table of Top Genes), each with tabs per response model (linear, lss_3m_2r, hier_poly_3m_2r, linear_lss_3m_2r, and their logistic counterparts) and covariate matrix (enhancer, ccle_rnaseq, juvenile, splicing). HTML markup omitted.]
+
+ +MDI+ (ridge) + +TreeSHAP + +MDI + +MDA + +MDI-oob +
+17-AAG +
+1 + +PCSK1N (1.47) + +PCSK1N (2.19) + +PCSK1N (2.19) + +PCSK1N (2.5) + +PCSK1N (9.16) +
+2 + +MMP24 (3.41) + +MMP24 (4.97) + +MMP24 (4.06) + +NQO1 (9.03) + +MMP24 (194.5) +
+3 + +RP11-109D20.2 (4.59) + +ZSCAN18 (6.94) + +ZSCAN18 (8.56) + +ZNF667-AS1 (26.59) + +ZNF667-AS1 (241.38) +
+4 + +ZSCAN18 (8.09) + +RP11-109D20.2 (7.41) + +RP11-109D20.2 (10.09) + +ZSCAN18 (44.03) + +RP11-109D20.2 (525.22) +
+5 + +NQO1 (8.84) + +NQO1 (9.53) + +NQO1 (11.81) + +TST (49.38) + +SH3BP1 (587.94) +
+AEW541 +
+1 + +TXNDC5 (1.41) + +TCEAL4 (1.62) + +TXNDC5 (1.75) + +TXNDC5 (1.5) + +TCEAL4 (5.59) +
+2 + +ATP8B2 (4.34) + +TXNDC5 (3.53) + +ATP8B2 (3.84) + +ATP8B2 (4.69) + +IQGAP2 (238.8) +
+3 + +VAV2 (6.03) + +ATP8B2 (8.38) + +VAV2 (5.41) + +VAV2 (5.84) + +RP11-343H19.2 (303.25) +
+4 + +TNFRSF17 (8.53) + +VAV2 (10.47) + +TCEAL4 (9.44) + +TCEAL4 (6.5) + +TXNDC5 (312.62) +
+5 + +TCEAL4 (9.03) + +PLEKHF1 (19.25) + +TNFRSF17 (9.69) + +TNFRSF17 (13.56) + +ATP8B2 (318.59) +
+AZD0530 +
+1 + +PRSS57 (5.16) + +SYTL1 (17.62) + +PRSS57 (7.69) + +PRSS57 (12.09) + +VTN (105.69) +
+2 + +SYTL1 (12.31) + +PRSS57 (32.5) + +SYTL1 (42.94) + +DDAH2 (440.16) + +SYTL1 (216.62) +
+3 + +STXBP2 (15.38) + +SFTA1P (36.5) + +NFE2 (43.06) + +SLC16A9 (484.34) + +STXBP2 (245.31) +
+4 + +NFE2 (23.12) + +STXBP2 (62.59) + +STXBP2 (61.62) + +STXBP2 (486.09) + +ZBED2 (466.48) +
+5 + +THEM4 (34.41) + +CLDN16 (67.28) + +SLC16A9 (61.81) + +RAPGEF3 (514.2) + +DDAH2 (472.06) +
+AZD6244 +
+1 + +LYZ (1.66) + +TOR4A (2.31) + +LYZ (1.72) + +LYZ (2.53) + +LYZ (2.09) +
+2 + +SPRY2 (2.34) + +SPRY2 (3.31) + +RP11-1143G9.4 (3.59) + +SPRY2 (2.59) + +RP11-1143G9.4 (3.59) +
+3 + +RP11-1143G9.4 (2.84) + +LYZ (3.69) + +SPRY2 (3.91) + +TOR4A (3.25) + +TOR4A (3.66) +
+4 + +ETV4 (5.22) + +ETV4 (5.19) + +TOR4A (6.41) + +RP11-1143G9.4 (4.75) + +SPRY2 (4.5) +
+5 + +TOR4A (6.41) + +RP11-1143G9.4 (6.84) + +RNF125 (6.66) + +ETV4 (6.91) + +ETV4 (6.34) +
+Erlotinib +
+1 + +CDH3 (1.47) + +CDH3 (1.84) + +CDH3 (2.03) + +CDH3 (1.97) + +CDH3 (1.88) +
+2 + +RP11-615I2.2 (2.28) + +RP11-615I2.2 (3.28) + +RP11-615I2.2 (2.88) + +RP11-615I2.2 (3.53) + +RP11-615I2.2 (3.16) +
+3 + +EGFR (4.34) + +SPRR1A (3.97) + +EGFR (3.97) + +SPRR1A (6.25) + +SPRR1A (3.84) +
+4 + +SPRR1A (4.44) + +SYTL1 (7.84) + +SPRR1A (4.19) + +GJB3 (8.78) + +EGFR (8.31) +
+5 + +GJB3 (7.44) + +EGFR (8.69) + +KRT16 (11.41) + +EGFR (9.31) + +SYTL1 (8.72) +
+Irinotecan +
+1 + +SLFN11 (1) + +SLFN11 (1) + +SLFN11 (1) + +SLFN11 (1) + +SLFN11 (1) +
+2 + +S100A16 (4.12) + +S100A16 (3.75) + +S100A16 (3.84) + +S100A16 (3.25) + +WWTR1 (6.03) +
+3 + +IFITM10 (4.19) + +IFITM10 (4.09) + +WWTR1 (4.28) + +WWTR1 (4.12) + +TRIM16L (150.38) +
+4 + +WWTR1 (4.94) + +WWTR1 (4.78) + +IFITM10 (8.03) + +RP11-359P5.1 (8.22) + +IFITM10 (163.44) +
+5 + +PPIC (7.81) + +PPIC (10.22) + +RP11-359P5.1 (8.47) + +IFITM10 (8.41) + +S100A16 (182.19) +
+L-685458 +
+1 + +PXK (2.03) + +DEF6 (4.62) + +PXK (2.03) + +PXK (3) + +PXK (2.94) +
+2 + +DEF6 (4.28) + +PXK (4.94) + +DEF6 (4.38) + +IKZF1 (3.44) + +CXorf21 (4.62) +
+3 + +CXorf21 (4.84) + +CXorf21 (5.44) + +CXorf21 (5.62) + +CXorf21 (5.75) + +DEF6 (4.66) +
+4 + +IKZF1 (6.03) + +IKZF1 (5.75) + +IKZF1 (7.91) + +DEF6 (10.31) + +IKZF1 (5.47) +
+5 + +RP11-359P5.1 (9.09) + +RP11-359P5.1 (9.94) + +RP11-359P5.1 (12.81) + +CTNNA1 (13.88) + +RP11-359P5.1 (10.06) +
+LBW242 +
+1 + +SERPINB6 (1.12) + +SERPINB6 (1) + +SERPINB6 (1.66) + +SERPINB6 (1.31) + +SERPINB6 (1.56) +
+2 + +RGS14 (5.12) + +RGS14 (6.66) + +RGS14 (3.66) + +GPT2 (45.62) + +HERC5 (54.31) +
+3 + +HERC5 (5.41) + +MAGEC1 (7.5) + +MAGEC1 (5.53) + +GBP1 (179.28) + +ITGA1 (73.34) +
+4 + +MAGEC1 (7.62) + +ITGA1 (10.53) + +GBP1 (5.78) + +ZNF32 (222.03) + +PTGS1 (296.62) +
+5 + +GBP1 (8.22) + +HERC5 (12.41) + +CCL2 (13.5) + +IGSF3 (257.88) + +GPT2 (316.12) +
+Lapatinib +
+1 + +ERBB2 (1.06) + +ERBB2 (1.5) + +ERBB2 (1.41) + +ERBB2 (1.69) + +ERBB2 (1.47) +
+2 + +PGAP3 (3.09) + +NA (6.81) + +PGAP3 (3.44) + +PGAP3 (8.03) + +NA (4.31) +
+3 + +NA (5.03) + +PGAP3 (12.41) + +IKBIP (6.19) + +C2orf54 (13.09) + +PGAP3 (14.03) +
+4 + +C2orf54 (6.91) + +DPYSL2 (16.16) + +NA (6.22) + +DPYSL2 (15.41) + +PKP3 (20.31) +
+5 + +IKBIP (8.28) + +PKP3 (16.47) + +C2orf54 (7.41) + +EMP3 (20.47) + +EMP3 (22.38) +
+Nilotinib +
+1 + +SPN (1.25) + +SPN (1.81) + +SPN (1.62) + +SPN (1.47) + +SPN (3.59) +
+2 + +GPC1 (3.5) + +GPC1 (4.38) + +GPC1 (3.5) + +GPC1 (3.44) + +SELPLG (9.03) +
+3 + +TRDC (6.62) + +SELPLG (7.5) + +TRDC (10.16) + +SELPLG (9.97) + +KLF13 (26.22) +
+4 + +SELPLG (6.78) + +KLF13 (16.97) + +LMO2 (10.44) + +TRDC (10.22) + +BCL2 (51.09) +
+5 + +LMO2 (9.44) + +TRDC (20.22) + +CISH (11.03) + +LMO2 (10.75) + +GPC1 (166.19) +
+Nutlin-3 +
+1 + +RP11-148O21.4 (1.41) + +MET (2.53) + +RP11-148O21.4 (1.59) + +RP11-148O21.4 (1.72) + +RAPGEF5 (37) +
+2 + +MET (2.53) + +RP11-148O21.4 (4.75) + +MET (4.06) + +LRRC16A (9.56) + +G6PD (63.53) +
+3 + +BLK (5.12) + +LAYN (6.41) + +BLK (5.25) + +BLK (10.97) + +MET (147.78) +
+4 + +LRRC16A (5.16) + +RPS27L (12.94) + +LRRC16A (5.97) + +MET (26.09) + +BLK (160.06) +
+5 + +LAT2 (7.34) + +ADD3 (21.03) + +LAT2 (7.78) + +LAYN (138.84) + +RP11-148O21.4 (164.94) +
+PD-0325901 +
+1 + +SPRY2 (1.16) + +SPRY2 (1.53) + +SPRY2 (1.75) + +SPRY2 (1.19) + +SPRY2 (1.75) +
+2 + +LYZ (2.72) + +ETV4 (2.88) + +LYZ (2.09) + +LYZ (2.88) + +LYZ (2.62) +
+3 + +ETV4 (2.72) + +LYZ (3.59) + +ETV4 (3.66) + +ETV4 (3.38) + +ETV4 (3.47) +
+4 + +RP11-1143G9.4 (4.34) + +TOR4A (4.56) + +RP11-1143G9.4 (4.53) + +TOR4A (4.72) + +TOR4A (4.88) +
+5 + +PLEKHG4B (5.62) + +PLEKHG4B (4.66) + +PLEKHG4B (5.31) + +RP11-1143G9.4 (5.38) + +PLEKHG4B (5.5) +
+PD-0332991 +
+1 + +SH2D3C (4.31) + +SH2D3C (6.81) + +SH2D3C (4.56) + +SH2D3C (8) + +KRT15 (10.56) +
+2 + +FMNL1 (6.56) + +FMNL1 (8.38) + +HSD3B7 (6.75) + +AL162151.3 (9.62) + +HSD3B7 (14.5) +
+3 + +HSD3B7 (6.59) + +AL162151.3 (11.59) + +FMNL1 (7.09) + +HSD3B7 (11.03) + +SEPT6 (16.44) +
+4 + +KRT15 (7.19) + +TWF1 (12.03) + +KRT15 (7.78) + +KRT15 (11.97) + +PPIC (17.03) +
+5 + +AL162151.3 (8.84) + +KRT15 (12.56) + +TWF1 (8.97) + +FMNL1 (16.34) + +AL162151.3 (18.34) +
+PF2341066 +
+1 + +ENAH (1.03) + +ENAH (1) + +ENAH (1.31) + +ENAH (1.12) + +ENAH (1.06) +
+2 + +SELPLG (2.09) + +SELPLG (2.81) + +SELPLG (2.06) + +SELPLG (2.28) + +SELPLG (2.47) +
+3 + +HGF (3.62) + +HGF (3.94) + +MET (5.44) + +HGF (7.03) + +CTD-2020K17.3 (10.31) +
+4 + +CTD-2020K17.3 (9.72) + +CTD-2020K17.3 (9.41) + +HGF (6) + +MET (10.69) + +HGF (14.06) +
+5 + +MET (10.53) + +MET (11.5) + +MLKL (12) + +CTD-2020K17.3 (11.88) + +DOK2 (14.41) +
+PHA-665752 +
+1 + +ARHGAP4 (1) + +ARHGAP4 (1.06) + +ARHGAP4 (1.06) + +ARHGAP4 (1.09) + +ARHGAP4 (1) +
+2 + +CTD-2020K17.3 (2.88) + +CTD-2020K17.3 (2.66) + +CTD-2020K17.3 (4.25) + +CTD-2020K17.3 (2.56) + +FMNL1 (4.22) +
+3 + +FMNL1 (4.88) + +FMNL1 (7.44) + +PFN2 (10.84) + +PFN2 (24.31) + +CTD-2020K17.3 (5.44) +
+4 + +PFN2 (8.06) + +PGPEP1 (13.28) + +FMNL1 (11.78) + +FMNL1 (33.12) + +INHBB (216.5) +
+5 + +PGPEP1 (10.12) + +INHBB (18.88) + +FDFT1 (18.66) + +MICB (59.69) + +PGPEP1 (335.97) +
+PLX4720 +
+1 + +RXRG (1) + +RXRG (1.16) + +RXRG (1) + +RXRG (1) + +RXRG (1.03) +
+2 + +MMP8 (5.19) + +MMP8 (3.97) + +MMP8 (5.16) + +MMP8 (5.56) + +MMP8 (4.88) +
+3 + +RP11-164J13.1 (6.28) + +RP11-599J14.2 (7.22) + +MYO5A (7.09) + +MYO5A (6.81) + +MYO5A (8.75) +
+4 + +RP11-599J14.2 (7) + +AP1S2 (8.47) + +LYST (7.97) + +LYST (11.31) + +AP1S2 (10.53) +
+5 + +RP4-718J7.4 (7.84) + +LYST (9.34) + +RP11-599J14.2 (8.41) + +RP4-718J7.4 (13.22) + +RP11-599J14.2 (12.69) +
+Paclitaxel +
+1 + +MMP24 (1.09) + +MMP24 (1.28) + +MMP24 (1.22) + +MMP24 (2.38) + +SH3BP1 (10.38) +
+2 + +AGAP2 (3.16) + +SH3BP1 (2.75) + +AGAP2 (3.44) + +SH3BP1 (2.88) + +PRODH (60.41) +
+3 + +SH3BP1 (3.5) + +AGAP2 (4.06) + +SH3BP1 (3.78) + +SLC38A5 (3.84) + +AGAP2 (157.41) +
+4 + +SLC38A5 (4.34) + +SLC38A5 (4.22) + +PTTG1IP (3.91) + +AGAP2 (5.88) + +SLC38A5 (179.34) +
+5 + +PTTG1IP (4.72) + +PTTG1IP (4.34) + +SLC38A5 (4.31) + +PTTG1IP (6.97) + +MMP24 (308.66) +
+Panobinostat +
+1 + +AGAP2 (1.12) + +AGAP2 (1.56) + +AGAP2 (1.81) + +AGAP2 (2.16) + +CYR61 (2.56) +
+2 + +CYR61 (2.44) + +CYR61 (2.09) + +CYR61 (2.16) + +CYR61 (2.78) + +AGAP2 (3.25) +
+3 + +RPL39P5 (4.19) + +RPL39P5 (4.41) + +RPL39P5 (3.75) + +RPL39P5 (3.88) + +RPL39P5 (152.44) +
+4 + +WWTR1 (5.16) + +WWTR1 (5.78) + +WWTR1 (5.94) + +WWTR1 (6.53) + +S100A2 (316.66) +
+5 + +MYOF (6.56) + +MYOF (6.16) + +IKZF1 (12.41) + +IKZF1 (9.72) + +MYOF (366.28) +
+RAF265 +
+1 + +CMTM3 (1.34) + +CMTM3 (1.5) + +CMTM3 (1.69) + +SH2B3 (34) + +CMTM3 (7.06) +
+2 + +SYT17 (5.69) + +SYT17 (5.66) + +SYT17 (7.69) + +CMTM3 (155.25) + +SH2B3 (470.66) +
+3 + +SH2B3 (6.03) + +SH2B3 (8.91) + +SH2B3 (17.5) + +SLC29A3 (159) + +SYT17 (652.84) +
+4 + +EMILIN2 (11.94) + +SLC29A3 (11.84) + +STAT5A (19.66) + +PRKCQ (235) + +RGS16 (713.47) +
+5 + +STAT5A (12.47) + +NA (19.91) + +SLC29A3 (22.22) + +LCP2 (259.7) + +AC007620.3 (1087.11) +
+Sorafenib +
+1 + +PXK (4.47) + +TP63 (6.72) + +PXK (4.12) + +FAM212A (6.41) + +PXK (9.34) +
+2 + +P2RX1 (4.62) + +P2RX1 (7.78) + +FAM212A (5.75) + +P2RX1 (8) + +ARHGAP9 (34.25) +
+3 + +FAM212A (4.69) + +PXK (8.88) + +STAC3 (7.16) + +SEC31B (41.69) + +P2RX1 (156.91) +
+4 + +STAC3 (5.16) + +FAM212A (16.97) + +P2RX1 (7.25) + +ARHGAP9 (43.19) + +FAM212A (187.31) +
+5 + +ARHGAP9 (7.91) + +STAC3 (20.16) + +TP63 (40.97) + +CXCL8 (57.72) + +SEC31B (222.41) +
+TAE684 +
+1 + +SELPLG (1.09) + +SELPLG (1.06) + +SELPLG (1.12) + +SELPLG (1.03) + +SELPLG (1.34) +
+2 + +IL6R (3.19) + +ARID3A (8.12) + +IL6R (3.34) + +IL6R (6.41) + +ARID3A (18.31) +
+3 + +NFIL3 (6.34) + +GALNT18 (8.62) + +NFIL3 (6.25) + +NFIL3 (8.06) + +FMNL1 (25.25) +
+4 + +ARID3A (7) + +IL6R (10.16) + +RRAS2 (10.19) + +ARID3A (16.38) + +RP11-334A14.2 (144.34) +
+5 + +RRAS2 (8.66) + +PPP2R3A (15.69) + +FMNL1 (10.19) + +RRAS2 (17.97) + +PPP2R3A (165.84) +
+TKI258 +
+1 + +TWF1 (2.44) + +TWF1 (3.31) + +TWF1 (2.41) + +TWF1 (2.28) + +LAPTM5 (20.84) +
+2 + +SLC43A1 (2.88) + +GPR162 (3.75) + +PRTN3 (3.31) + +LAPTM5 (5.34) + +SLC43A1 (156.12) +
+3 + +PRTN3 (5.25) + +SLC43A1 (6.84) + +SLC43A1 (6.06) + +PRTN3 (6.12) + +TWF1 (163.78) +
+4 + +LAPTM5 (5.78) + +TTC28 (8.94) + +LAT2 (6.62) + +SLC43A1 (14.91) + +GPR162 (169.66) +
+5 + +LAT2 (5.94) + +LAPTM5 (11.53) + +LAPTM5 (8.19) + +LAT2 (17.41) + +LYL1 (387.97) +
+Topotecan +
+1 + +SLFN11 (1) + +SLFN11 (1) + +SLFN11 (1) + +SLFN11 (1) + +SLFN11 (1) +
+2 + +HSPB8 (2.16) + +HSPB8 (2.28) + +HSPB8 (2.84) + +HSPB8 (2.06) + +HSPB8 (2.62) +
+3 + +PPIC (5.69) + +OSGIN1 (5.28) + +OSGIN1 (5.31) + +PPIC (7.81) + +OSGIN1 (7.19) +
+4 + +OSGIN1 (5.88) + +AGAP2 (8.81) + +PPIC (6.81) + +RP11-359P5.1 (10.75) + +AGAP2 (15.25) +
+5 + +AGAP2 (6.53) + +PPIC (8.91) + +RP11-359P5.1 (8.25) + +CORO1A (12.06) + +HMGB2 (21.53) +
+ZD-6474 +
+1 + +MAP3K12 (1) + +MAP3K12 (1) + +MAP3K12 (1) + +MAP3K12 (1.09) + +MAP3K12 (1) +
+2 + +PIM1 (5.47) + +CTSH (21.88) + +PIM1 (10.28) + +PIM1 (31.44) + +SCD5 (339.58) +
+3 + +PRKCQ (9.03) + +TIMP1 (26.31) + +PRKCQ (15.28) + +DYNLT3 (192.12) + +ITGA10 (527.41) +
+4 + +CTSH (12.91) + +PRKCQ (26.47) + +CTSH (18.16) + +TIMP1 (211.25) + +TIMP1 (569.81) +
+5 + +ITGA10 (20.81) + +PIM1 (33.81) + +ANXA5 (78.78) + +EPHA1 (320.22) + +CTSH (631.5) +
+
+
+

[Per-drug figure tabs for the 24 drugs listed above.]

TCGA BRCA Case Study: Non-zero Stability Scores Per Method; Stability of Top Features; Table of Top Genes.

Table of Top Genes (TCGA BRCA case study): top 25 genes under each feature importance method.
| Rank | MDI+ (ridge) | MDI+ (logistic) | MDA | TreeSHAP | MDI |
|------|--------------|-----------------|-----|----------|-----|
| 1 | ESR1 (1.91) | ESR1 (1.91) | ESR1 (4.5) | ESR1 (7.62) | ESR1 (13.91) |
| 2 | FOXA1 (4.25) | GATA3 (4.5) | GATA3 (6.38) | TPX2 (10.41) | TPX2 (15.34) |
| 3 | FOXC1 (6.12) | FOXA1 (5.09) | FOXA1 (8.11) | GATA3 (19.62) | FOXM1 (22.84) |
| 4 | GATA3 (6.97) | TPX2 (6.81) | TPX2 (10.12) | FOXM1 (20.06) | MLPH (24.97) |
| 5 | AGR3 (7.94) | AGR3 (10.22) | MLPH (10.16) | FOXA1 (20.72) | FOXA1 (25.66) |
| 6 | MLPH (8.16) | FOXC1 (12.94) | AGR3 (12.94) | CDK1 (22.38) | GATA3 (30.44) |
| 7 | TPX2 (11.03) | MLPH (15.69) | TBC1D9 (14.22) | MLPH (22.53) | CDK1 (31.41) |
| 8 | TBC1D9 (14.44) | FOXM1 (18.12) | FOXC1 (15.09) | AGR3 (25.88) | THSD4 (34.69) |
| 9 | FOXM1 (18.66) | TBC1D9 (21.03) | FOXM1 (19.88) | PLK1 (28.47) | FOXC1 (35) |
| 10 | THSD4 (21.78) | THSD4 (23.66) | THSD4 (21.28) | TBC1D9 (29.84) | TBC1D9 (35.44) |
| 11 | SPDEF (25.81) | CDK1 (24.44) | CDK1 (21.77) | FOXC1 (30.44) | AGR3 (36.44) |
| 12 | CA12 (29.44) | MYBL2 (25.34) | XBP1 (24.88) | MYBL2 (33.06) | PLK1 (36.81) |
| 13 | CDK1 (36.09) | RACGAP1 (26.81) | KIF2C (27.95) | THSD4 (33.53) | MYBL2 (45.22) |
| 14 | GABRP (36.72) | ASPM (27.56) | PLK1 (28.58) | KIF2C (35.56) | KIF2C (46.22) |
| 15 | PLK1 (37.97) | PLK1 (28.84) | MYBL2 (30.97) | ASPM (40.59) | ASPM (49.72) |
| 16 | FAM171A1 (38.59) | UBE2C (30.41) | GABRP (31.94) | GMPS (41.41) | GMPS (52.03) |
| 17 | ASPM (39.5) | SPAG5 (31.81) | ASPM (35.97) | XBP1 (50.44) | SPDEF (58.88) |
| 18 | SFRP1 (40.25) | GMPS (35.34) | FAM171A1 (38.55) | CENPF (56.69) | FAM171A1 (68.56) |
| 19 | XBP1 (40.44) | KIF2C (36.03) | CA12 (40.16) | MKI67 (59.69) | CA12 (69.59) |
| 20 | MYBL2 (43.12) | CA12 (37.31) | CDC20 (48.45) | CA12 (59.91) | UBE2C (69.78) |
| 21 | TFF3 (43.12) | RRM2 (46.09) | SPDEF (53.05) | RACGAP1 (60.53) | TFF3 (70) |
| 22 | KIF2C (44.62) | CENPF (46.31) | UBE2C (54.27) | SPAG5 (61.94) | CENPF (70.31) |
| 23 | PRR15 (44.88) | GABRP (47.59) | ANXA9 (55.41) | FAM171A1 (63) | RACGAP1 (70.94) |
| 24 | AGR2 (45) | SFRP1 (49.34) | C1orf64 (57.22) | KIF11 (63.56) | SPAG5 (71.81) |
| 25 | MIA (45.31) | XBP1 (49.78) | RACGAP1 (57.97) | ANLN (64.28) | KIF11 (73.12) |

TCGA BRCA Case Study (continued): Prediction using Top Features.

[Further section navigation from the rendered report:]

- Misspecified Models: linear, lss_3m_2r, hier_poly_3m_2r, linear_lss_3m_2r, each over enhancer, ccle_rnaseq, juvenile, and splicing
- Varying Sparsity: linear, lss_3m_2r, hier_poly_3m_2r, linear_lss_3m_2r, each over juvenile and splicing
- Varying # Features: linear, lss_3m_2r, hier_poly_3m_2r, linear_lss_3m_2r, each over ccle_rnaseq
- Prediction Results
- MDI+ Modeling Choices: linear and hier_poly_3m_2r, each over enhancer and ccle_rnaseq at min_samples_per_leaf = 5 and min_samples_per_leaf = 1
- MDI+ GLM/Metric Choices: Held-out Test Prediction Scores (linear, hier_poly_3m_2r over enhancer and ccle_rnaseq); Stability Scores - Regression (linear, hier_poly_3m_2r over enhancer and ccle_rnaseq); Stability Scores - Classification (logistic, hier_poly_3m_2r_logistic over juvenile and splicing)

The R Markdown source that generates this summary follows.
---
title: "MDI+ Simulation Results Summary"
author: ""
date: "`r format(Sys.time(), '%B %d, %Y')`"
output: vthemes::vmodern
params:
  results_dir: 
    label: "Results directory"
    value: "/global/scratch/users/tiffanytang/feature_importance/results_final/"
  seed:
    label: "Seed"
    value: 12345
  for_paper:
    label: "Export plots for paper"
    value: FALSE
  use_cached:
    label: "Use cached .rds files"
    value: TRUE
  interactive:
    label: "Interactive plots"
    value: FALSE
---
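
The YAML parameters above control where cached results are read from and how figures are rendered. As a minimal sketch of how this summary can be knitted with non-default parameters (the `.Rmd` file name and results path below are placeholders, not taken from the repository):

```r
# Hypothetical invocation: knit the summary against a local results directory.
rmarkdown::render(
  "results_summary.Rmd",                       # placeholder file name for this document
  params = list(
    results_dir = "/path/to/results_final/",   # placeholder path to simulation outputs
    seed = 12345,
    for_paper = FALSE,                         # TRUE writes the paper's PDF figures instead of embedding plots
    use_cached = TRUE,                         # reuse .rds caches rather than re-reading raw .csv files
    interactive = TRUE                         # render plotly versions of the figures
  ),
  output_file = "results_summary.html"
)
```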

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = FALSE, warning = FALSE, message = FALSE)

library(magrittr)
library(patchwork)
chunk_idx <- 1

# set parameters
results_dir <- params$results_dir
seed <- params$seed
tables_dir <- file.path("tables")
figures_dir <- file.path("figures")
figures_subdirs <- c("regression_sims", "classification_sims", "robust_sims",
                     "misspecified_regression_sims", "varying_sparsity",
                     "varying_p", "modeling_choices", "glm_metric_choices")
for (figures_subdir in figures_subdirs) {
  if (!dir.exists(file.path(figures_dir, figures_subdir))) {
    dir.create(file.path(figures_dir, figures_subdir), recursive = TRUE)
  }
}

# miscellaneous helper variables
heritabilities <- c(0.1, 0.2, 0.4, 0.8)
frac_label_corruptions <- c("0.25", "0.15", "0.05", "0")
corrupt_sizes <- c(0.05, 0.025, 0.01, 0)
mean_shifts <- c(10, 25)
metric <- "rocauc"

# plot options
point_size <- 2
line_size <- 1
errbar_width <- 0
if (params$interactive) {
  plot_fun <- plotly::ggplotly
} else {
  plot_fun <- function(x) x
}

manual_color_palette_choices <- c(
  "black", "black", "#9B5DFF", "blue",
  "orange", "#71beb7", "#218a1e", "#cc3399"
)
show_methods_choices <- c(
  "GMDI_ridge_RF", "GMDI_ridge_loo_r2_RF", "GMDI_logistic_logloss_RF", "GMDI_Huber_loo_huber_loss_RF",
  "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF"
)
method_labels_choices <- c(
  "MDI+ (ridge)", "MDI+ (ridge)", "MDI+ (logistic)", "MDI+ (Huber)",
  "MDA", "MDI", "MDI-oob", "TreeSHAP"
)
color_df <- tibble::tibble(
  color = manual_color_palette_choices,
  name = show_methods_choices,
  label = method_labels_choices
)

manual_color_palette_all <- NULL
show_methods_all <- NULL
method_labels_all <- ggplot2::waiver()

custom_theme <- vthemes::theme_vmodern(
  size_preset = "medium", bg_color = "white", grid_color = "white",
  axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
  axis.text = ggplot2::element_text(size = 14),
  axis.title = ggplot2::element_text(size = 20, face = "plain"),
  legend.text = ggplot2::element_text(size = 14),
  plot.title = ggplot2::element_blank()
  # plot.title = ggplot2::element_text(size = 12, face = "plain", hjust = 0.5)
)
custom_theme_with_legend <- vthemes::theme_vmodern(
  size_preset = "medium", bg_color = "white", grid_color = "white",
  axis.title = ggplot2::element_text(size = 12, face = "plain"),
  legend.text = ggplot2::element_text(size = 9),
  legend.text.align = 0,
  plot.title = ggplot2::element_blank()
)
custom_theme_without_legend <- vthemes::theme_vmodern(
  size_preset = "medium", bg_color = "white", grid_color = "white",
  axis.title = ggplot2::element_text(size = 12, face = "plain"),
  legend.title = ggplot2::element_blank(),
  legend.text = ggplot2::element_text(size = 9),
  legend.text.align = 0,
  plot.title = ggplot2::element_blank()
)

fig_height <- 6
fig_width <- 10

source("../../scripts/viz.R", chdir = TRUE)
```


# Regression Simulations {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
keep_methods <- color_df$name %in% c("GMDI_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
manual_color_palette <- color_df$color[keep_methods]
show_methods <- color_df$name[keep_methods]
method_labels <- color_df$label[keep_methods]
alpha_values <- c(1, rep(0.4, length(method_labels) - 1))
legend_position <- c(0.73, 0.35)

vary_param_name <- "heritability_sample_row_n"
y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")

remove_x_axis_models <- c("enhancer", "ccle_rnaseq", "juvenile")
keep_legend_x_models <- c("enhancer")
keep_legend_y_models <- y_models

for (y_model in y_models) {
  cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n### %s \n\n", x_model))
    plt_ls <- list()
    sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    metric_name <- dplyr::case_when(
      metric == "rocauc" ~ "AUROC",
      metric == "prauc" ~ "PRAUC"
    )
    fname <- file.path(results_dir, paste0("mdi_plus.regression_sims.", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results()
      saveRDS(results, sprintf("%s.rds", fname))
    }
    if (params$for_paper) {
      for (h in heritabilities) {
        plt <- results %>%
          dplyr::filter(heritability == !!h) %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_row_n",
            facet_str = NULL,
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            legend_position = legend_position,
            custom_theme = custom_theme,
            inside_legend = TRUE
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name, 
            color = "Method", alpha = "Method"
          )
        if (h != heritabilities[1]) {
          plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
        }
        if (x_model %in% remove_x_axis_models) {
          height <- 2.72
          plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
        } else {
          height <- 3
        }
        if (!((h == heritabilities[length(heritabilities)]) & 
              (x_model %in% keep_legend_x_models) & 
              (y_model %in% keep_legend_y_models))) {
          plt <- plt + ggplot2::guides(color = "none", alpha = "none")
        }
        plt_ls[[as.character(h)]] <- plt
      }
      plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
      ggplot2::ggsave(
        file.path(figures_dir, "regression_sims",
                  sprintf("regression_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)),
        plot = plt, units = "in", width = 14, height = height
      )
    } else {
      plt <- results %>%
        plot_metrics(
          metric = metric,
          x_str = "sample_row_n",
          facet_str = "heritability_name",
          point_size = point_size,
          line_size = line_size,
          errbar_width = errbar_width,
          manual_color_palette = manual_color_palette,
          show_methods = show_methods,
          method_labels = method_labels,
          alpha_values = alpha_values,
          custom_theme = custom_theme
        ) +
        ggplot2::labs(
          x = "Sample Size", y = metric_name,
          color = "Method", alpha = "Method",
          title = sprintf("%s", sim_title)
        )
      vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                           fig_height = fig_height, fig_width = fig_width * 1.3)
      chunk_idx <- chunk_idx + 1
    }
  }
}

```
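
Each simulation section repeats the same read-or-cache pattern: load the cached `.rds` results if they exist and cover all requested methods, otherwise re-read the raw `.csv`, reformat it, and refresh the cache. A small helper along the following lines could factor this out; `load_results()` is a hypothetical name and is not defined in this repository:

```r
# Hypothetical helper (sketch only): return cached results when they cover all
# requested methods; otherwise rebuild the cache from the raw csv.
load_results <- function(fname, show_methods, use_cached = TRUE) {
  rds_file <- sprintf("%s.rds", fname)
  if (file.exists(rds_file) && use_cached) {
    results <- readRDS(rds_file)
    # fall through to a rebuild if any requested method is missing from the cache
    if (length(setdiff(show_methods, unique(results$method))) == 0) {
      return(results)
    }
  }
  results <- data.table::fread(sprintf("%s.csv", fname)) %>%
    reformat_results()
  saveRDS(results, rds_file)
  results
}
```

With such a helper, each chunk's caching block would reduce to a single call, e.g. `results <- load_results(fname, show_methods, params$use_cached)`.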

```{r eval = TRUE, results = "asis"}
y_models <- c("lss_3m_2r", "hier_poly_3m_2r")
x_models <- c("splicing")

remove_x_axis_models <- c("lss_3m_2r")
keep_legend_x_models <- x_models
keep_legend_y_models <- c("lss_3m_2r")

if (params$for_paper) {
  for (y_model in y_models) {
    for (x_model in x_models) {
      plt_ls <- list()
      sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
      sim_title <- dplyr::case_when(
        x_model == "ccle_rnaseq" ~ "CCLE",
        x_model == "splicing" ~ "Splicing",
        x_model == "enhancer" ~ "Enhancer",
        x_model == "juvenile" ~ "Juvenile"
      )
      metric_name <- dplyr::case_when(
        metric == "rocauc" ~ "AUROC",
        metric == "prauc" ~ "PRAUC"
      )
      fname <- file.path(results_dir, paste0("mdi_plus.regression_sims.", sim_name),
                         paste0("varying_", vary_param_name), 
                         paste0("seed", seed), "results")
      if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
        results <- readRDS(sprintf("%s.rds", fname))
        if (length(setdiff(show_methods, unique(results$method))) > 0) {
          results <- data.table::fread(sprintf("%s.csv", fname)) %>%
            reformat_results()
          saveRDS(results, sprintf("%s.rds", fname))
        }
      } else {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
      for (h in heritabilities) {
        plt <- results %>%
          dplyr::filter(heritability == !!h) %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_row_n",
            facet_str = NULL,
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            legend_position = legend_position,
            custom_theme = custom_theme,
            inside_legend = TRUE
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name, 
            color = "Method", alpha = "Method"
          )
        if (h != heritabilities[1]) {
          plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
        }
        if (y_model %in% remove_x_axis_models) {
          height <- 2.72
          plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
        } else {
          height <- 3
        }
        if (!((h == heritabilities[length(heritabilities)]) & 
              (x_model %in% keep_legend_x_models) & 
              (y_model %in% keep_legend_y_models))) {
          plt <- plt + ggplot2::guides(color = "none", alpha = "none")
        }
        plt_ls[[as.character(h)]] <- plt
      }
      plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
      ggplot2::ggsave(
        file.path(figures_dir, "regression_sims",
                  sprintf("main_regression_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)),
        plot = plt, units = "in", width = 14, height = height
      )
    }
  }
}

```


# Classification Simulations {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
keep_methods <- color_df$name %in% c("GMDI_ridge_RF", "GMDI_logistic_logloss_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
manual_color_palette <- color_df$color[keep_methods]
show_methods <- color_df$name[keep_methods]
method_labels <- color_df$label[keep_methods]
alpha_values <- c(1, 1, rep(0.4, length(method_labels) - 2))
legend_position <- c(0.73, 0.4)

vary_param_name <- "frac_label_corruption_sample_row_n"
y_models <- c("logistic", "lss_3m_2r_logistic", "hier_poly_3m_2r_logistic", "linear_lss_3m_2r_logistic")
x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")

remove_x_axis_models <- c("enhancer", "ccle_rnaseq", "juvenile")
keep_legend_x_models <- c("enhancer")
keep_legend_y_models <- y_models

for (y_model in y_models) {
  cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n### %s \n\n", x_model))
    plt_ls <- list()
    sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    metric_name <- dplyr::case_when(
      metric == "rocauc" ~ "AUROC",
      metric == "prauc" ~ "PRAUC"
    )
    fname <- file.path(results_dir, paste0("mdi_plus.classification_sims.", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results()
      saveRDS(results, sprintf("%s.rds", fname))
    }
    if (params$for_paper) {
      for (h in frac_label_corruptions) {
        plt <- results %>%
          dplyr::filter(frac_label_corruption_name == !!h) %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_row_n",
            facet_str = NULL,
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            legend_position = legend_position,
            custom_theme = custom_theme,
            inside_legend = TRUE
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name, 
            color = "Method", alpha = "Method"
          )
        if (h != frac_label_corruptions[1]) {
          plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
        }
        if (x_model %in% remove_x_axis_models) {
          plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          height <- 2.72
        } else {
          height <- 3
        }
        if (!((h == frac_label_corruptions[length(frac_label_corruptions)]) & 
              (x_model %in% keep_legend_x_models) & 
              (y_model %in% keep_legend_y_models))) {
          plt <- plt + ggplot2::guides(color = "none", alpha = "none")
        }
        plt_ls[[as.character(h)]] <- plt
      }
      plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
      ggplot2::ggsave(
        file.path(figures_dir, "classification_sims",
                  sprintf("classification_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)),
        plot = plt, units = "in", width = 14, height = height
      )
    } else {
      plt <- results %>%
        plot_metrics(
          metric = metric,
          x_str = "sample_row_n",
          facet_str = "frac_label_corruption_name",
          point_size = point_size,
          line_size = line_size,
          errbar_width = errbar_width,
          manual_color_palette = manual_color_palette,
          show_methods = show_methods,
          method_labels = method_labels,
          alpha_values = alpha_values,
          custom_theme = custom_theme
        ) +
        ggplot2::labs(
          x = "Sample Size", y = metric_name,
          color = "Method", alpha = "Method",
          title = sprintf("%s", sim_title)
        )
      vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                           fig_height = fig_height, fig_width = fig_width * 1.3)
      chunk_idx <- chunk_idx + 1
    }
  }
}
```

```{r eval = TRUE, results = "asis"}
y_models <- c("logistic", "linear_lss_3m_2r_logistic")
x_models <- c("ccle_rnaseq")

remove_x_axis_models <- c("logistic")
keep_legend_x_models <- x_models
keep_legend_y_models <- c("logistic")

if (params$for_paper) {
  for (y_model in y_models) {
    for (x_model in x_models) {
      plt_ls <- list()
      sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
      sim_title <- dplyr::case_when(
        x_model == "ccle_rnaseq" ~ "CCLE",
        x_model == "splicing" ~ "Splicing",
        x_model == "enhancer" ~ "Enhancer",
        x_model == "juvenile" ~ "Juvenile"
      )
      metric_name <- dplyr::case_when(
        metric == "rocauc" ~ "AUROC",
        metric == "prauc" ~ "PRAUC"
      )
      fname <- file.path(results_dir, paste0("mdi_plus.classification_sims.", sim_name),
                         paste0("varying_", vary_param_name), 
                         paste0("seed", seed), "results")
      if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
        results <- readRDS(sprintf("%s.rds", fname))
        if (length(setdiff(show_methods, unique(results$method))) > 0) {
          results <- data.table::fread(sprintf("%s.csv", fname)) %>%
            reformat_results()
          saveRDS(results, sprintf("%s.rds", fname))
        }
      } else {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
      for (h in frac_label_corruptions) {
        plt <- results %>%
          dplyr::filter(frac_label_corruption_name == !!h) %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_row_n",
            facet_str = NULL,
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            legend_position = legend_position,
            custom_theme = custom_theme,
            inside_legend = TRUE
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name, 
            color = "Method", alpha = "Method"
          )
        if (h != frac_label_corruptions[1]) {
          plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
        }
        if (y_model %in% remove_x_axis_models) {
          plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          height <- 2.72
        } else {
          height <- 3
        }
        if (!((h == frac_label_corruptions[length(frac_label_corruptions)]) & 
              (x_model %in% keep_legend_x_models) & 
              (y_model %in% keep_legend_y_models))) {
          plt <- plt + ggplot2::guides(color = "none", alpha = "none")
        }
        plt_ls[[as.character(h)]] <- plt
      }
      plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
      ggplot2::ggsave(
        file.path(figures_dir, "classification_sims",
                  sprintf("main_classification_sims_%s_%s_%s_errbars.pdf", y_model, x_model, metric)),
        plot = plt, units = "in", width = 14, height = height
      )
    }
  }
}
```


# Robust Simulations {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
keep_methods <- color_df$name %in% c("GMDI_ridge_loo_r2_RF", "GMDI_Huber_loo_huber_loss_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
manual_color_palette <- color_df$color[keep_methods]
show_methods <- color_df$name[keep_methods]
method_labels <- color_df$label[keep_methods]
alpha_values <- c(1, 1, rep(0.4, length(method_labels) - 2))
legend_position <- c(0.73, 0.4)

vary_param_name <- "corrupt_size_sample_row_n"
y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
x_models <- c("enhancer", "ccle_rnaseq")

remove_x_axis_models <- c(10)  # mean shift settings whose panels omit the x-axis title
keep_legend_x_models <- c("enhancer", "ccle_rnaseq")
keep_legend_y_models <- c("linear", "linear_lss_3m_2r")
keep_legend_mean_shifts <- c(10)

for (y_model in y_models) {
  cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-circle} \n\n", x_model))
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    for (mean_shift in mean_shifts) {
      cat(sprintf("\n\n#### Mean Shift = %s \n\n", mean_shift))
      plt_ls <- list()
      sim_name <- sprintf("%s_%s_%sMS_robust_dgp", x_model, y_model, mean_shift)
      fname <- file.path(results_dir, paste0("mdi_plus.robust_sims.", sim_name),
                         paste0("varying_", vary_param_name), 
                         paste0("seed", seed), "results")
      if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
        results <- readRDS(sprintf("%s.rds", fname))
        if (length(setdiff(show_methods, unique(results$method))) > 0) {
          results <- data.table::fread(sprintf("%s.csv", fname)) %>%
            reformat_results()
          saveRDS(results, sprintf("%s.rds", fname))
        }
      } else {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
      metric_name <- dplyr::case_when(
        metric == "rocauc" ~ "AUROC",
        metric == "prauc" ~ "PRAUC",
        TRUE ~ metric
      )
      if (params$for_paper) {
        tmp <- results %>%
          dplyr::filter(corrupt_size_name %in% corrupt_sizes,
                        method %in% show_methods) %>%
          dplyr::group_by(sample_row_n, corrupt_size_name, method) %>%
          dplyr::summarise(mean = mean(.data[[metric]]))
        min_y <- min(tmp$mean)
        max_y <- max(tmp$mean)
        for (h in corrupt_sizes) {
          plt <- results %>%
            dplyr::filter(corrupt_size_name == !!h) %>%
            plot_metrics(
              metric = metric,
              x_str = "sample_row_n",
              facet_str = NULL,
              point_size = point_size,
              line_size = line_size,
              errbar_width = errbar_width,
              manual_color_palette = manual_color_palette,
              show_methods = show_methods,
              method_labels = method_labels,
              alpha_values = alpha_values,
              legend_position = legend_position,
              custom_theme = custom_theme,
              inside_legend = TRUE
            ) +
            ggplot2::labs(
              x = "Sample Size", y = metric_name, 
              color = "Method", alpha = "Method"
            )
          if (h != corrupt_sizes[1]) {
            plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
          }
          if (mean_shift %in% remove_x_axis_models) {
            height <- 2.72
            plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          } else {
            height <- 3
          }
          if (!((h == corrupt_sizes[length(corrupt_sizes)]) & 
                (mean_shift %in% keep_legend_mean_shifts) &
                (x_model %in% keep_legend_x_models) & 
                (y_model %in% keep_legend_y_models))) {
            plt <- plt + ggplot2::guides(color = "none", alpha = "none")
          }
          plt_ls[[as.character(h)]] <- plt
        }
        plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
        ggplot2::ggsave(
          file.path(figures_dir, "robust_sims",
                    sprintf("robust_sims_%s_%s_%sMS_%s_errbars.pdf", y_model, x_model, mean_shift, metric)),
          plot = plt, units = "in", width = 14, height = height
        )
      } else {
        plt <- results %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_row_n",
            facet_str = "corrupt_size_name",
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            custom_theme = custom_theme
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name,
            color = "Method", alpha = "Method",
            title = sprintf("%s", sim_title)
          )
        vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                             fig_height = fig_height, fig_width = fig_width * 1.3)
        chunk_idx <- chunk_idx + 1
      }
    }
  }
}

```

```{r eval = TRUE, results = "asis"}
y_models <- c("lss_3m_2r")
x_models <- c("enhancer")

remove_x_axis_models <- c(10)
keep_legend_x_models <- x_models
keep_legend_y_models <- y_models
keep_legend_mean_shifts <- c(10)

if (params$for_paper) {
  for (y_model in y_models) {
    for (x_model in x_models) {
      sim_title <- dplyr::case_when(
        x_model == "ccle_rnaseq" ~ "CCLE",
        x_model == "splicing" ~ "Splicing",
        x_model == "enhancer" ~ "Enhancer",
        x_model == "juvenile" ~ "Juvenile"
      )
      for (mean_shift in mean_shifts) {
        plt_ls <- list()
        sim_name <- sprintf("%s_%s_%sMS_robust_dgp", x_model, y_model, mean_shift)
        fname <- file.path(results_dir, paste0("mdi_plus.robust_sims.", sim_name),
                           paste0("varying_", vary_param_name), 
                           paste0("seed", seed), "results")
        if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
          results <- readRDS(sprintf("%s.rds", fname))
          if (length(setdiff(show_methods, unique(results$method))) > 0) {
            results <- data.table::fread(sprintf("%s.csv", fname)) %>%
              reformat_results()
            saveRDS(results, sprintf("%s.rds", fname))
          }
        } else {
          results <- data.table::fread(sprintf("%s.csv", fname)) %>%
            reformat_results()
          saveRDS(results, sprintf("%s.rds", fname))
        }
        metric_name <- dplyr::case_when(
          metric == "rocauc" ~ "AUROC",
          metric == "prauc" ~ "PRAUC",
          TRUE ~ metric
        )
        tmp <- results %>%
          dplyr::filter(corrupt_size_name %in% corrupt_sizes,
                        method %in% show_methods) %>%
          dplyr::group_by(sample_row_n, corrupt_size_name, method) %>%
          dplyr::summarise(mean = mean(.data[[metric]]))
        min_y <- min(tmp$mean)
        max_y <- max(tmp$mean)
        for (h in corrupt_sizes) {
          plt <- results %>%
            dplyr::filter(corrupt_size_name == !!h) %>%
            plot_metrics(
              metric = metric,
              x_str = "sample_row_n",
              facet_str = NULL,
              point_size = point_size,
              line_size = line_size,
              errbar_width = errbar_width,
              manual_color_palette = manual_color_palette,
              show_methods = show_methods,
              method_labels = method_labels,
              alpha_values = alpha_values,
              legend_position = legend_position,
              custom_theme = custom_theme,
              inside_legend = TRUE
            ) +
            ggplot2::labs(
              x = "Sample Size", y = metric_name, 
              color = "Method", alpha = "Method"
            )
          if (h != corrupt_sizes[1]) {
            plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
          }
          if (mean_shift %in% remove_x_axis_models) {
            height <- 2.72
            plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          } else {
            height <- 3
          }
          if (!((h == corrupt_sizes[length(corrupt_sizes)]) & 
                (mean_shift %in% keep_legend_mean_shifts) &
                (x_model %in% keep_legend_x_models) & 
                (y_model %in% keep_legend_y_models))) {
            plt <- plt + ggplot2::guides(color = "none", alpha = "none")
          }
          plt_ls[[as.character(h)]] <- plt
        }
        plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
        ggplot2::ggsave(
          file.path(figures_dir, "robust_sims",
                    sprintf("main_robust_sims_%s_%s_%sMS_%s_errbars.pdf", y_model, x_model, mean_shift, metric)),
          plot = plt, units = "in", width = 14, height = height
        )
      }
    }
  }
}

```


# Correlation Bias Simulations {.tabset .tabset-vmodern}

## Main Figures {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
manual_color_palette <- c("black", "#218a1e", "orange", "#71beb7", "#cc3399")
show_methods <- c("GMDI", "MDI-oob", "MDA", "MDI",  "TreeSHAP")
method_labels <- c("MDI+ (ridge)", "MDI-oob", "MDA", "MDI", "TreeSHAP")

custom_color_palette <- c("#38761d", "#9dc89b", "#991500")
keep_heritabilities <- c(0.1, 0.4)
sim_name <- "mdi_plus.mdi_bias_sims.correlation_sims.normal_block_cor_partial_linear_lss_dgp"
fname <- file.path(
  results_dir, 
  sim_name,
  "varying_heritability_rho",
  sprintf("seed%s", seed),
  "results.csv"
)
results <- data.table::fread(fname) %>%
  dplyr::filter(heritability_name %in% keep_heritabilities)

n <- 250
p <- 100
sig_ids <- 0:5
cnsig_ids <- 6:49
nreps <- length(unique(results$rep))

#### Examine average rank across varying correlation levels ####
plt_base <- plot_perturbation_stability(
  results,
  facet_rows = "heritability_name",
  facet_cols = "rho_name",
  param_name = NULL,
  sig_ids = sig_ids,
  cnsig_ids = cnsig_ids,
  plot_types = "errbar",
  manual_color_palette = manual_color_palette,
  show_methods = show_methods,
  method_labels = method_labels
)

for (idx in 1:length(keep_heritabilities)) {
  h <- keep_heritabilities[idx]
  cat(sprintf("\n\n### PVE = %s\n\n", h))
  plt_df <- plt_base$agg[[idx]]$data
  plt_ls <- list()
  for (method in method_labels) {
    plt_ls[[method]] <- plt_df %>%
      dplyr::filter(fi == method) %>%
      ggplot2::ggplot() +
      ggplot2::aes(x = rho_name, y = .mean, color = group) +
      ggplot2::geom_line() +
      ggplot2::geom_ribbon(
        ggplot2::aes(x = rho_name, 
                     ymin = .mean - (.sd / sqrt(nreps)), 
                     ymax = .mean + (.sd / sqrt(nreps)), 
                     fill = group), 
        inherit.aes = FALSE, alpha = 0.2
      ) +
      ggplot2::labs(
        x = expression("Correlation ("*rho*")"), 
        y = "Average Rank", 
        color = "Feature\nGroup", 
        title = method
      ) +
      ggplot2::coord_cartesian(
        ylim = c(min(plt_df$.mean) - 3, max(plt_df$.mean) + 3)
      ) +
      ggplot2::scale_color_manual(
        values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
      ) +
      ggplot2::scale_fill_manual(
        values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
      ) +
      ggplot2::guides(fill = "none", linetype = "none") +
      vthemes::theme_vmodern(
        size_preset = "medium", bg_color = "white", grid_color = "white",
        axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
        axis.text = ggplot2::element_text(size = 14),
        axis.title = ggplot2::element_text(size = 18, face = "plain"),
        legend.text = ggplot2::element_text(size = 14),
        legend.title = ggplot2::element_text(size = 18),
        plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
      )
    if (method != method_labels[1]) {
      plt_ls[[method]] <- plt_ls[[method]] +
        ggplot2::theme(axis.title.y = ggplot2::element_blank(),
                       axis.ticks.y = ggplot2::element_blank(),
                       axis.text.y = ggplot2::element_blank())
    }
  }
  
  plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1)
  if (params$for_paper) {
    ggplot2::ggsave(plt_wide, filename = file.path(figures_dir, sprintf("correlation_sim_pve%s_wide.pdf", h)),
                    # height = fig_height * .55, width = fig_width * .25 * length(show_methods))
                    height = fig_height * .55, width = fig_width * .3 * length(show_methods))
  } else {
    vthemes::subchunkify(
      plt_wide, i = chunk_idx, fig_height = fig_height * .8, fig_width = fig_width * 1.3
    )
    chunk_idx <- chunk_idx + 1
  }
}


#### Examine number of splits across varying correlation levels ####
cat("\n\n### Number of RF Splits \n\n")
results <- data.table::fread(fname) %>%
  dplyr::filter(heritability_name %in% keep_heritabilities,
                fi == "MDI_with_splits")

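# total number of splits per (rho, heritability, rep), used below to express
# per-feature split counts (av_splits) as percentages of all RF splits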
total_splits <- results %>%
  dplyr::group_by(rho_name, heritability, rep) %>%
  dplyr::summarise(n_splits = sum(av_splits))

plt_df <- results %>%
  dplyr::left_join(
    total_splits, by = c("rho_name", "heritability", "rep")
  ) %>%
  dplyr::mutate(
    av_splits = av_splits / n_splits * 100,
    group = dplyr::case_when(
      var %in% sig_ids ~ "Sig",
      var %in% cnsig_ids ~ "C-NSig",
      TRUE ~ "NSig"
    )
  ) %>%
  dplyr::mutate(
    group = factor(group, levels = c("Sig", "C-NSig", "NSig"))
  ) %>%
  dplyr::group_by(rho_name, heritability, group) %>%
  dplyr::summarise(
    .mean = mean(av_splits),
    .sd = sd(av_splits)
  )

min_y <- min(plt_df$.mean)
max_y <- max(plt_df$.mean)

plt_ls <- list()
for (idx in 1:length(keep_heritabilities)) {
  h <- keep_heritabilities[idx]
  plt <- plt_df %>%
    dplyr::filter(heritability == !!h) %>%
    ggplot2::ggplot() +
    ggplot2::aes(x = rho_name, y = .mean, color = group) +
    ggplot2::geom_line(size = 1) +
    # ggplot2::geom_ribbon(
    #   ggplot2::aes(x = rho_name, 
    #                ymin = .mean - (.sd / sqrt(nreps)), 
    #                ymax = .mean + (.sd / sqrt(nreps)), 
    #                fill = group), 
    #   inherit.aes = FALSE, alpha = 0.2
    # ) +
    ggplot2::labs(
      x = expression("Correlation ("*rho*")"), 
      y = "Percentage of Splits in RF (%)",
      color = "Feature\nGroup", 
      title = sprintf("PVE = %s", h)
    ) +
    ggplot2::coord_cartesian(ylim = c(min_y, max_y)) +
    ggplot2::scale_color_manual(
      values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
    ) +
    ggplot2::scale_fill_manual(
      values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
    ) +
    ggplot2::guides(fill = "none", linetype = "none") +
    vthemes::theme_vmodern(
      size_preset = "medium", bg_color = "white", grid_color = "white",
      axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
      axis.text = ggplot2::element_text(size = 14),
      axis.title = ggplot2::element_text(size = 18, face = "plain"),
      legend.text = ggplot2::element_text(size = 14),
      legend.title = ggplot2::element_text(size = 18),
      plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
    )
  if (idx != 1) {
    plt <- plt +
      ggplot2::theme(axis.title.y = ggplot2::element_blank())
  }
  plt_ls[[idx]] <- plt
}

plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1)
if (params$for_paper) {
  ggplot2::ggsave(plt_wide & ggplot2::theme(plot.title = ggplot2::element_blank()), 
                  filename = file.path(figures_dir, "correlation_sim_num_splits.pdf"),
                  width = fig_width * 1, height = fig_height * .65)
} else {
  vthemes::subchunkify(
    plt_wide, i = chunk_idx, fig_height = fig_height, fig_width = fig_width * 1.3
  )
  chunk_idx <- chunk_idx + 1
}

```

## Appendix Figures {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
manual_color_palette <- c("black", "gray")
show_methods <- c("GMDI", "GMDI_inbag")
method_labels <- c("MDI+ (LOO)", "MDI+ (in-bag)")

custom_color_palette <- c("#38761d", "#9dc89b", "#991500")
keep_heritabilities <- c(0.1, 0.4)
sim_name <- "mdi_plus.mdi_bias_sims.correlation_sims.normal_block_cor_partial_linear_lss_dgp"
fname <- file.path(
  results_dir, 
  sim_name,
  "varying_heritability_rho",
  sprintf("seed%s", seed),
  "results.csv"
)
results <- data.table::fread(fname) %>%
  dplyr::filter(heritability_name %in% keep_heritabilities)

n <- 250
p <- 100
sig_ids <- 0:5
cnsig_ids <- 6:49
nreps <- length(unique(results$rep))

#### Examine average rank across varying correlation levels ####
plt_base <- plot_perturbation_stability(
  results,
  facet_rows = "heritability_name",
  facet_cols = "rho_name",
  param_name = NULL,
  sig_ids = sig_ids,
  cnsig_ids = cnsig_ids,
  plot_types = "errbar",
  manual_color_palette = manual_color_palette,
  show_methods = show_methods,
  method_labels = method_labels
)

for (idx in 1:length(keep_heritabilities)) {
  h <- keep_heritabilities[idx]
  cat(sprintf("\n\n### PVE = %s\n\n", h))
  plt_df <- plt_base$agg[[idx]]$data
  plt_ls <- list()
  for (method in method_labels) {
    plt_ls[[method]] <- plt_df %>%
      dplyr::filter(fi == method) %>%
      ggplot2::ggplot() +
      ggplot2::aes(x = rho_name, y = .mean, color = group) +
      ggplot2::geom_line() +
      ggplot2::geom_ribbon(
        ggplot2::aes(x = rho_name, 
                     ymin = .mean - (.sd / sqrt(nreps)), 
                     ymax = .mean + (.sd / sqrt(nreps)), 
                     fill = group), 
        inherit.aes = FALSE, alpha = 0.2
      ) +
      ggplot2::labs(
        x = expression("Correlation ("*rho*")"), 
        y = "Average Rank", 
        color = "Feature\nGroup", 
        title = method
      ) +
      ggplot2::coord_cartesian(
        ylim = c(min(plt_df$.mean) - 3, max(plt_df$.mean) + 3)
      ) +
      ggplot2::scale_color_manual(
        values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
      ) +
      ggplot2::scale_fill_manual(
        values = custom_color_palette, guide = ggplot2::guide_legend(reverse = TRUE)
      ) +
      ggplot2::guides(fill = "none", linetype = "none") +
      vthemes::theme_vmodern(
        size_preset = "medium", bg_color = "white", grid_color = "white",
        axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
        axis.text = ggplot2::element_text(size = 14),
        axis.title = ggplot2::element_text(size = 18, face = "plain"),
        legend.text = ggplot2::element_text(size = 14),
        legend.title = ggplot2::element_text(size = 18),
        plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
      )
    if (method != method_labels[1]) {
      plt_ls[[method]] <- plt_ls[[method]] +
        ggplot2::theme(axis.title.y = ggplot2::element_blank(),
                       axis.ticks.y = ggplot2::element_blank(),
                       axis.text.y = ggplot2::element_blank())
    }
  }
  
  plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1)
  if (params$for_paper) {
    ggplot2::ggsave(plt_wide, 
                    filename = file.path(figures_dir, sprintf("correlation_sim_pve%s_wide_appendix.pdf", h)),
                    # height = fig_height * .55, width = fig_width * .25 * length(show_methods))
                    height = fig_height * .55, width = fig_width * .37 * length(show_methods))
  } else {
    vthemes::subchunkify(
      plt_wide, i = chunk_idx, fig_height = fig_height, fig_width = fig_width * 1.3
    )
    chunk_idx <- chunk_idx + 1
  }
}
```


# Entropy Bias Simulations {.tabset .tabset-vmodern}

## Main Figures {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
#### Entropy Regression Results ####
manual_color_palette <- rev(c("black", "orange", "#71beb7", "#218a1e", "#cc3399"))
show_methods <- rev(c("GMDI", "MDA", "MDI", "MDI-oob", "TreeSHAP"))
method_labels <- rev(c("MDI+ (ridge)", "MDA", "MDI", "MDI-oob", "TreeSHAP"))
alpha_values <- rev(c(1, rep(0.7, length(method_labels) - 1)))
y_limits <- c(1, 4.5)

metric_name <- "AUROC"
sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.linear_dgp"
fname <- file.path(
  results_dir, 
  sim_name,
  "varying_heritability_n",
  sprintf("seed%s", seed),
  "results.csv"
)

#### Entropy Regression Avg Rank ####
results <- data.table::fread(fname) %>%
  dplyr::filter(heritability_name == 0.1)

custom_group_fun <- function(var) var

plt_base <- plot_perturbation_stability(
  results,
  facet_rows = "heritability_name",
  facet_cols = "n_name", 
  group_fun = custom_group_fun, 
  param_name = NULL,
  plot_types = "errbar",
  manual_color_palette = manual_color_palette,
  show_methods = show_methods,
  method_labels = method_labels
)

plt_df <- plt_base$agg[[1]]$data
nreps <- length(unique(results$rep))

plt_regression <- plt_df %>%
  dplyr::filter(group == 0) %>%
  ggplot2::ggplot() +
  ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi) +
  ggplot2::geom_line(size = line_size) +
  ggplot2::labs(
    x = "Sample Size",
    y = expression("Average Rank of "*X[1]),
    color = "Method",
    title = "Regression"
  ) +
  ggplot2::guides(fill = "none", alpha = "none") +
  ggplot2::scale_color_manual(values = manual_color_palette,
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::scale_alpha_manual(values = alpha_values, 
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::coord_cartesian(ylim = y_limits) +
  vthemes::theme_vmodern(
    size_preset = "medium", bg_color = "white", grid_color = "white",
    axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
    axis.text = ggplot2::element_text(size = 14),
    axis.title = ggplot2::element_text(size = 18, face = "plain"),
    legend.text = ggplot2::element_text(size = 14),
    legend.title = ggplot2::element_text(size = 18),
    plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
  )

#### Entropy Regression # Splits ####
results <- data.table::fread(fname) %>%
  dplyr::filter(heritability_name == 0.1, 
                fi == "MDI_with_splits")

nreps <- length(unique(results$rep))

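# Convert each feature's split count into a percentage of all RF splits within a
# (sample size, heritability, replicate) cell, then average across replicates per feature.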
total_splits <- results %>%
  dplyr::group_by(n_name, heritability, rep) %>%
  dplyr::summarise(n_splits = sum(av_splits))

regression_num_splits_df <- results %>%
  dplyr::left_join(
    total_splits, by = c("n_name", "heritability", "rep")
  ) %>%
  dplyr::mutate(
    av_splits = av_splits / n_splits * 100
  ) %>%
  dplyr::group_by(n_name, heritability, var) %>%
  dplyr::summarise(
    .mean = mean(av_splits),
    .sd = sd(av_splits)
  )

#### Entropy Classification Results ####
manual_color_palette <- rev(c("black", "#9B5DFF", "orange", "#71beb7", "#218a1e", "#cc3399"))
show_methods <- rev(c("GMDI_ridge", "GMDI_logistic_logloss", "MDA", "MDI", "MDI-oob", "TreeSHAP"))
method_labels <- rev(c("MDI+ (ridge)", "MDI+ (logistic)", "MDA", "MDI", "MDI-oob", "TreeSHAP"))
alpha_values <- rev(c(1, 1, rep(0.7, length(method_labels) - 2)))

metric_name <- "AUROC"
sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.logistic_dgp"
fname <- file.path(
  results_dir, 
  sim_name,
  "varying_c_n",
  sprintf("seed%s", seed),
  "results.csv"
)

#### Entropy Classification Avg Rank ####
results <- data.table::fread(fname)

custom_group_fun <- function(var) var

plt_base <- plot_perturbation_stability(
  results,
  facet_rows = "c_name",
  facet_cols = "n_name", 
  group_fun = custom_group_fun, 
  param_name = NULL,
  plot_types = "errbar",
  manual_color_palette = manual_color_palette,
  show_methods = show_methods,
  method_labels = method_labels
)

plt_df <- plt_base$agg[[1]]$data
nreps <- length(unique(results$rep))

plt_classification <- plt_df %>%
  dplyr::filter(group == 0) %>%
  ggplot2::ggplot() +
  ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi) +
  ggplot2::geom_line(size = line_size) +
  ggplot2::labs(
    x = "Sample Size", 
    y = expression("Average Rank of "*X[1]), 
    color = "Method",
    title = "Classification"
  ) +
  ggplot2::guides(fill = "none", alpha = "none") +
  ggplot2::scale_color_manual(values = manual_color_palette,
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::scale_alpha_manual(values = alpha_values, 
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::coord_cartesian(
    ylim = y_limits
  ) +
  vthemes::theme_vmodern(
    size_preset = "medium", bg_color = "white", grid_color = "white",
    axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
    axis.text = ggplot2::element_text(size = 14),
    axis.title = ggplot2::element_text(size = 18, face = "plain"),
    legend.text = ggplot2::element_text(size = 14),
    legend.title = ggplot2::element_text(size = 18),
    plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
  )

#### Entropy Classification # Splits ####
results <- data.table::fread(fname) %>%
  dplyr::filter(fi == "MDI_with_splits")

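# Same split-percentage normalization as the regression setting above, now grouped by
# the classification parameter c instead of heritability.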
total_splits <- results %>%
  dplyr::group_by(n_name, c, rep) %>%
  dplyr::summarise(n_splits = sum(av_splits))

classification_num_splits_df <- results %>%
  dplyr::left_join(
    total_splits, by = c("n_name", "c", "rep")
  ) %>%
  dplyr::mutate(
    av_splits = av_splits / n_splits * 100
  ) %>%
  dplyr::group_by(n_name, c, var) %>%
  dplyr::summarise(
    .mean = mean(av_splits),
    .sd = sd(av_splits)
  )

#### Show Avg Rank Plot ####
cat("\n\n### Average Rank \n\n")
plt <- plt_regression + plt_classification
if (params$for_paper) {
  ggplot2::ggsave(plt, filename = file.path(figures_dir, "entropy_sims.pdf"), 
                  width = fig_width * 1.2, height = fig_height * .65)
} else {
  vthemes::subchunkify(
    plt, i = chunk_idx, fig_height = fig_height * .8, fig_width = fig_width * 1.3
  )
  chunk_idx <- chunk_idx + 1
}

#### Show # Splits Plot ####
cat("\n\n### Number of RF Splits \n\n")
plt_df <- dplyr::bind_rows(
  Regression = regression_num_splits_df, 
  Classification = classification_num_splits_df, 
  .id = "type"
) %>%
  dplyr::mutate(
    var = dplyr::case_when(
      var == 0 ~ "Bernoulli (Signal)",
      var == 1 ~ "Normal (Non-Signal)",
      var == 2 ~ "4-Categories (Non-Signal)",
      var == 3 ~ "10-Categories (Non-Signal)",
      var == 4 ~ "20-Categories (Non-Signal)"
    ) %>%
      factor(levels = c("Normal (Non-Signal)",
                        "20-Categories (Non-Signal)",
                        "10-Categories (Non-Signal)",
                        "4-Categories (Non-Signal)",
                        "Bernoulli (Signal)"))
  )

sim_types <- unique(plt_df$type)
plt_ls <- list()
for (idx in 1:length(sim_types)) {
  sim_type <- sim_types[idx]
  plt <- plt_df %>%
    dplyr::filter(type == !!sim_type) %>%
    ggplot2::ggplot() +
    ggplot2::aes(x = n_name, y = .mean, color = as.factor(var)) +
    ggplot2::geom_line(size = 1) +
    ggplot2::labs(
      x = "Sample Size",
      y = "Percentage of Splits in RF (%)",
      color = "Feature", 
      title = sim_type
    ) +
    ggplot2::guides(fill = "none", linetype = "none") +
    vthemes::theme_vmodern(
      size_preset = "medium", bg_color = "white", grid_color = "white",
      axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
      axis.text = ggplot2::element_text(size = 14),
      axis.title = ggplot2::element_text(size = 18, face = "plain"),
      legend.text = ggplot2::element_text(size = 14),
      legend.title = ggplot2::element_text(size = 18),
      plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5)
    )
  if (idx != 1) {
    plt <- plt +
      ggplot2::theme(axis.title.y = ggplot2::element_blank())
  }
  plt_ls[[idx]] <- plt
}

plt_wide <- patchwork::wrap_plots(plt_ls, guides = "collect", nrow = 1)
if (params$for_paper) {
  ggplot2::ggsave(plt_wide, filename = file.path(figures_dir, "entropy_sims_num_splits.pdf"),
                  # height = fig_height * .55, width = fig_width * .25 * length(show_methods))
                  width = fig_width * 1.1, height = fig_height * .65)
} else {
  vthemes::subchunkify(
    plt_wide, i = chunk_idx, fig_height = fig_height, fig_width = fig_width * 1.3
  )
  chunk_idx <- chunk_idx + 1
}

```

## Appendix Figures {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
#### Entropy Regression Results ####
manual_color_palette <- rev(c("black", "black", "#71beb7"))
show_methods <- rev(c("GMDI", "GMDI_inbag", "MDI"))
method_labels <- rev(c("MDI+ (ridge, LOO)", "MDI+ (ridge, in-bag)", "MDI"))
alpha_values <- rep(1, length(method_labels))
linetype_values <- c("solid", "dashed", "solid")
y_limits <- c(1, 5)

metric_name <- "AUROC"
sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.linear_dgp"
fname <- file.path(
  results_dir, 
  sim_name,
  "varying_heritability_n",
  sprintf("seed%s", seed),
  "results.csv"
)

#### Entropy Regression Avg Rank ####
results <- data.table::fread(fname) %>%
  dplyr::filter(heritability_name == 0.1)

custom_group_fun <- function(var) var

plt_base <- plot_perturbation_stability(
  results,
  facet_rows = "heritability_name",
  facet_cols = "n_name", 
  group_fun = custom_group_fun, 
  param_name = NULL,
  plot_types = "errbar",
  manual_color_palette = manual_color_palette,
  show_methods = show_methods,
  method_labels = method_labels
)

plt_df <- plt_base$agg[[1]]$data
nreps <- length(unique(results$rep))

plt_regression <- plt_df %>%
  dplyr::filter(group == 0) %>%
  ggplot2::ggplot() +
  ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi, linetype = fi) +
  ggplot2::geom_line(size = line_size) +
  ggplot2::labs(
    x = "Sample Size", 
    y = expression("Average Rank of "*X[1]), 
    color = "Method",
    linetype = "Method",
    title = "Regression"
  ) +
  ggplot2::guides(fill = "none", alpha = "none") +
  ggplot2::scale_color_manual(values = manual_color_palette,
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::scale_alpha_manual(values = alpha_values, 
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::scale_linetype_manual(values = linetype_values, 
                                 labels = method_labels,
                                 guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::coord_cartesian(ylim = y_limits) +
  vthemes::theme_vmodern(
    size_preset = "medium", bg_color = "white", grid_color = "white",
    axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
    axis.text = ggplot2::element_text(size = 14),
    axis.title = ggplot2::element_text(size = 18, face = "plain"),
    legend.text = ggplot2::element_text(size = 14),
    legend.title = ggplot2::element_text(size = 18),
    plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5),
    legend.key.width = ggplot2::unit(1, "cm")
  )

#### Entropy Classification Results ####
manual_color_palette <- rev(c("black", "black", "#9B5DFF", "#9B5DFF", "#71beb7"))
show_methods <- rev(c("GMDI_ridge", "GMDI_ridge_inbag", "GMDI_logistic_logloss", "GMDI_logistic_logloss_inbag", "MDI"))
method_labels <- rev(c("MDI+ (ridge, LOO)", "MDI+ (ridge, in-bag)", "MDI+ (logistic, LOO)", "MDI+ (logistic, in-bag)", "MDI"))
alpha_values <- rep(1, length(method_labels))
linetype_values <- c("solid", "dashed", "solid", "dashed", "solid")

metric_name <- "AUROC"
sim_name <- "mdi_plus.mdi_bias_sims.entropy_sims.logistic_dgp"
fname <- file.path(
  results_dir, 
  sim_name,
  "varying_c_n",
  sprintf("seed%s", seed),
  "results.csv"
)

#### Entropy Classification Avg Rank ####
results <- data.table::fread(fname)

custom_group_fun <- function(var) var

plt_base <- plot_perturbation_stability(
  results,
  facet_rows = "c_name",
  facet_cols = "n_name", 
  group_fun = custom_group_fun, 
  param_name = NULL,
  plot_types = "errbar",
  manual_color_palette = manual_color_palette,
  show_methods = show_methods,
  method_labels = method_labels
)

plt_df <- plt_base$agg[[1]]$data
nreps <- length(unique(results$rep))

plt_classification <- plt_df %>%
  dplyr::filter(group == 0) %>%
  ggplot2::ggplot() +
  ggplot2::aes(x = n_name, y = .mean, color = fi, alpha = fi, linetype = fi) +
  ggplot2::geom_line(size = line_size) +
  ggplot2::labs(
    x = "Sample Size", 
    y = expression("Average Rank of "*X[1]), 
    color = "Method",
    linetype = "Method",
    title = "Classification"
  ) +
  ggplot2::guides(fill = "none", alpha = "none") +
  ggplot2::scale_color_manual(values = manual_color_palette,
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::scale_alpha_manual(values = alpha_values, 
                              labels = method_labels,
                              guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::scale_linetype_manual(values = linetype_values,
                                 labels = method_labels,
                                 guide = ggplot2::guide_legend(reverse = TRUE)) +
  ggplot2::coord_cartesian(
    ylim = y_limits
  ) +
  vthemes::theme_vmodern(
    size_preset = "medium", bg_color = "white", grid_color = "white",
    axis.ticks = ggplot2::element_line(size = ggplot2::rel(2), colour = "black"),
    axis.text = ggplot2::element_text(size = 14),
    axis.title = ggplot2::element_text(size = 18, face = "plain"),
    legend.text = ggplot2::element_text(size = 14),
    legend.title = ggplot2::element_text(size = 18),
    plot.title = ggplot2::element_text(size = 20, face = "bold", hjust = 0.5),
    legend.key.width = ggplot2::unit(1, "cm")
  )

#### Show Avg Rank Plot ####
cat("\n\n### Average Rank \n\n")
plt <- plt_regression + plt_classification
if (params$for_paper) {
  ggplot2::ggsave(plt, filename = file.path(figures_dir, "entropy_sims_appendix.pdf"), 
                  width = fig_width * 1.35, height = fig_height * .65)
} else {
  vthemes::subchunkify(
    plt, i = chunk_idx, fig_height = fig_height * .7, fig_width = fig_width * 1.3
  )
  chunk_idx <- chunk_idx + 1
}

```


# CCLE Case Study {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
manual_color_palette <- c("black", "orange", "#71beb7", "#218a1e", "#cc3399")
show_methods <- c("MDI+", "MDA", "MDI", "MDI-oob", "TreeSHAP")
method_labels <- c("MDI+ (ridge)", "MDA", "MDI", "MDI-oob", "TreeSHAP")
alpha_values <- c(1, rep(0.4, length(method_labels) - 1))

fpath <- "mdi_plus.real_data_case_study.ccle_rnaseq_regression-"
fname <- file.path(
  results_dir, 
  fpath,
  sprintf("seed%s", seed),
  "results.csv"
)
pred_fname <- file.path(
  results_dir, 
  fpath,
  sprintf("seed%s", seed),
  "pred_results.csv"
)
X <- data.table::fread("/global/scratch/users/tiffanytang/feature_importance/data/X_ccle_rnaseq_cleaned_filtered5000.csv")
X_columns <- colnames(X)
results <- data.table::fread(fname)
pred_results <- data.table::fread(pred_fname) %>%
  tibble::as_tibble() %>%
  dplyr::mutate(method = fi) %>%
  # tidyr::unite(col = "method", fi, model, na.rm = TRUE, remove = FALSE) %>%
  tidyr::pivot_wider(names_from = metric, values_from = metric_value)

out <- plot_top_stability(
  results,
  group_id = "y_task",
  top_r = 10,
  show_max_features = 5,
  varnames = colnames(X),
  base_method = "MDI+ (ridge)",
  return_df = TRUE,
  manual_color_palette = rev(manual_color_palette),
  show_methods = rev(show_methods),
  method_labels = rev(method_labels)
)

ranking_df <- out$rankings
stability_df <- out$stability
plt_ls <- out$plot_ls

# get gene names instead of ENSG ids
library(EnsDb.Hsapiens.v79)
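# Strip the trailing version suffix from each Ensembl gene ID (everything after the last
# ".") so the keys match the unversioned GENEIDs in the EnsDb annotation.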
varnames <- ensembldb::select(
  EnsDb.Hsapiens.v79, 
  keys = stringr::str_remove(colnames(X), "\\.[^\\.]+$"), 
  keytype = "GENEID", columns = c("GENEID","SYMBOL")
)

# rank genes by their mean rank across training-test splits, per drug and method
top_genes <- ranking_df %>%
  dplyr::group_by(y_task, fi, var) %>%
  dplyr::summarise(mean_rank = mean(rank)) %>%
  dplyr::ungroup() %>%
  dplyr::arrange(y_task, mean_rank) %>%
  dplyr::group_by(y_task, fi) %>%
  dplyr::mutate(rank = 1:dplyr::n()) %>%
  dplyr::ungroup()

top_genes_table <- top_genes %>%
  dplyr::filter(rank <= 50) %>%
  dplyr::left_join(
    y = data.frame(var = 0:(ncol(X) - 1), 
                   GENEID = stringr::str_remove(colnames(X), "\\.[^\\.]+$")),
    by = "var"
  ) %>%
  dplyr::left_join(
    y = varnames, by = "GENEID"
  ) %>%
  dplyr::mutate(
    value = sprintf("%s (%s)", SYMBOL, round(mean_rank, 2))
  ) %>%
  tidyr::pivot_wider(
    id_cols = c(y_task, rank), 
    names_from = fi, 
    values_from = value
  ) %>%
  dplyr::rename(
    "Drug" = "y_task",
    "Rank" = "rank"
  )

# write.csv(
#   top_genes_table, file.path(tables_dir, "ccle_rnaseq_top_genes_table.csv"),
#   row.names = FALSE, quote = FALSE
# )

# get # genes with perfect stability
stable_genes_kable <- stability_df %>%
  dplyr::group_by(fi, y_task) %>%
  dplyr::summarise(stability_1 = sum(stability_score == 1)) %>%
  tidyr::pivot_wider(id_cols = fi, names_from = y_task, values_from = stability_1) %>%
  dplyr::ungroup() %>%
  as.data.frame()

pred_plt_ls <- list()
for (plt_id in names(plt_ls)) {
  cat(sprintf("\n\n## %s \n\n", plt_id))
  plt <- plt_ls[[plt_id]]
  if (plt_id == "Summary") {
    plt <- plt + ggplot2::labs(y = "Drug")
    if (params$for_paper) {
      ggplot2::ggsave(plt, filename = file.path(figures_dir, "ccle_rnaseq_stability_top10.pdf"), 
                      width = fig_width * 1.5, height = fig_height * 1.1)
    } else {
      vthemes::subchunkify(
        plt, i = sprintf("%s-%s", fpath, plt_id), 
        fig_height = fig_height, fig_width = fig_width * 1.3
      )
    }
    
    cat("\n\n## Table of Top Genes \n\n")
    if (params$for_paper) {
      top_genes_kable <- top_genes_table %>%
        dplyr::filter(Rank <= 5) %>%
        dplyr::select(-Drug) %>%
        dplyr::rename(" " = "Rank") %>%
        vthemes::pretty_kable(format = "latex", longtable = TRUE) %>%
        kableExtra::kable_styling(latex_options = "repeat_header") %>%
        # kableExtra::collapse_rows(columns = 1, valign = "middle")
        kableExtra::pack_rows(
          index = rep(5, 24) %>% setNames(unique(top_genes_table$Drug))
        )
    } else {
      top_genes_kable <- top_genes_table %>%
        dplyr::filter(Rank <= 5) %>%
        dplyr::select(-Drug) %>%
        dplyr::rename(" " = "Rank") %>%
        vthemes::pretty_kable() %>%
        kableExtra::kable_styling(latex_options = "repeat_header") %>%
        # kableExtra::collapse_rows(columns = 1, valign = "middle")
        kableExtra::pack_rows(
          index = rep(5, 24) %>% setNames(unique(top_genes_table$Drug))
        )
      vthemes::subchunkify(
        top_genes_kable, i = sprintf("%s-table", fpath)
      )
    }
  } else {
    pred_plt_ls[[plt_id]] <- pred_results %>%
      dplyr::filter(k <= 10, y_task == !!plt_id) %>%
      plot_metrics(
        metric = "r2",
        x_str = "k",
        facet_str = NULL,
        point_size = point_size,
        line_size = line_size,
        errbar_width = errbar_width,
        manual_color_palette = manual_color_palette,
        show_methods = show_methods,
        method_labels = method_labels,
        alpha_values = alpha_values,
        inside_legend = FALSE
      ) +
      ggplot2::labs(x = "Number of Top Features (k)", y = "Test R-squared",
                    color = "Method", alpha = "Method", title = plt_id) +
      vthemes::theme_vmodern()
    
    if (!params$for_paper) {
      vthemes::subchunkify(
        plot_fun(plt), i = sprintf("%s-%s", fpath, plt_id), 
        fig_height = fig_height, fig_width = fig_width * 1.3
      )
      vthemes::subchunkify(
        plot_fun(pred_plt_ls[[plt_id]]), i = sprintf("%s-%s-pred", fpath, plt_id), 
        fig_height = fig_height, fig_width = fig_width * 1.3
      )
    }
  }
}

if (params$for_paper) {
  all_plt_ls <- list()
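  # Arrange the per-drug prediction plots on a 4-column x 6-row grid, keeping the y-axis
  # title only in the first column and the x-axis title only in the bottom row.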
  for (idx in 1:length(pred_plt_ls)) {
    drug <- names(pred_plt_ls)[idx]
    plt <- pred_plt_ls[[drug]]
    if ((idx - 1) %% 4 != 0) {
      plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
    }
    if ((idx - 1) %/% 4 != 5) {
      plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
    }
    all_plt_ls[[drug]] <- plt
  }
  plt <- patchwork::wrap_plots(all_plt_ls, guides = "collect", ncol = 4, nrow = 6) &
    ggplot2::theme(
      plot.title = ggplot2::element_text(hjust = 0.5),
      panel.background = ggplot2::element_rect(fill = "white"),
      panel.grid.major = ggplot2::element_line(color = "white")
    )
  ggplot2::ggsave(plt, 
                  filename = file.path(figures_dir, "ccle_rnaseq_predictions.pdf"),
                  width = fig_width * 1, height = fig_height * 2)
  
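  # Rank the methods by mean test R-squared using their top 10 features, separately for
  # each drug, then tabulate how often each method attains each rank across drugs.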
  pred10_ranks <- pred_results %>% 
    dplyr::filter(k == 10) %>% 
    dplyr::group_by(fi, y_task) %>% 
    dplyr::summarise(r2 = mean(r2)) %>% 
    tidyr::pivot_wider(names_from = y_task, values_from = r2) %>% 
    dplyr::ungroup() %>% 
    dplyr::mutate(dplyr::across(`17-AAG`:`ZD-6474`, ~rank(-.x)))
  pred10_ranks_table <- apply(pred10_ranks, 1, FUN = function(x) table(x[-1])) %>%
    tibble::as_tibble() %>%
    tibble::rownames_to_column("Rank") %>%
    setNames(c("Rank", pred10_ranks$fi)) %>%
    dplyr::select(Rank, `MDI+`, TreeSHAP, MDI, `MDI-oob`, MDA) %>%
    vthemes::pretty_kable(format = "latex")
  top_genes_kable <- top_genes_table %>%
    dplyr::filter(Rank <= 5) %>%
    dplyr::select(-Drug) %>%
    dplyr::rename(" " = "Rank") %>%
    vthemes::pretty_kable() %>%
    kableExtra::kable_styling(latex_options = "repeat_header") %>%
    # kableExtra::collapse_rows(columns = 1, valign = "middle")
    kableExtra::pack_rows(
      index = rep(5, 24) %>% setNames(unique(top_genes_table$Drug))
    )
  vthemes::subchunkify(
    top_genes_kable, i = sprintf("%s-table", fpath)
  )
}
```

```{r eval = TRUE, results = "asis"}
if (params$for_paper) {
  manual_color_palette_full <- c("black", "orange", "#71beb7", "#218a1e", "#cc3399")
  show_methods_full <- c("MDI+", "MDA", "MDI", "MDI-oob", "TreeSHAP")
  method_labels_full <- c("MDI+ (ridge)", "MDA", "MDI", "MDI-oob", "TreeSHAP")
  keep_methods_list <- list(
    all = show_methods_full,
    no_mdioob = c("MDI+", "MDA", "MDI", "TreeSHAP"),
    zoom = c("MDI+", "MDI", "TreeSHAP")
  )
  
  fpath <- "mdi_plus.real_data_case_study.ccle_rnaseq_regression-"
  fname <- file.path(
    results_dir, 
    fpath,
    sprintf("seed%s", seed),
    "results.csv"
  )
  X <- data.table::fread("/global/scratch/users/tiffanytang/feature_importance/data/X_ccle_rnaseq_cleaned_filtered5000.csv")
  results <- data.table::fread(fname)
  
  for (id in names(keep_methods_list)) {
    keep_methods <- keep_methods_list[[id]]
    manual_color_palette <- manual_color_palette_full[show_methods_full %in% keep_methods]
    show_methods <- show_methods_full[show_methods_full %in% keep_methods]
    method_labels <- method_labels_full[show_methods_full %in% keep_methods]
    
    plt_ls <- plot_top_stability(
      results,
      group_id = "y_task",
      top_r = 10,
      show_max_features = 5,
      varnames = colnames(X),
      base_method = "MDI+",
      return_df = FALSE,
      manual_color_palette = manual_color_palette,
      show_methods = show_methods,
      method_labels = method_labels
    )
    all_plt_ls <- list()
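    # As above: 4 x 6 grid of per-drug stability plots, with axis titles kept only on the
    # left column and bottom row.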
    for (idx in 1:length(unique(results$y_task))) {
      drug <- sort(unique(results$y_task))[idx]
      plt <- plt_ls[[drug]]
      if ((idx - 1) %% 4 != 0) {
        plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
      }
      if ((idx - 1) %/% 4 != 5) {
        plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
      }
      all_plt_ls[[drug]] <- plt
    }
    plt <- patchwork::wrap_plots(all_plt_ls, guides = "collect", ncol = 4, nrow = 6) &
      ggplot2::theme(
        plot.title = ggplot2::element_text(hjust = 0.5),
        panel.background = ggplot2::element_rect(fill = "white"),
        panel.grid.major = ggplot2::element_line(color = "grey95", size = ggplot2::rel(0.25))
      )
    ggplot2::ggsave(plt, 
                    filename = file.path(figures_dir, sprintf("ccle_rnaseq_stability_select_features_%s.pdf", id)),
                    width = fig_width * 1, height = fig_height * 2)
  }
}
```


# TCGA BRCA Case Study {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
manual_color_palette <- c("black", "#9B5DFF", "orange", "#71beb7", "#cc3399")
show_methods <- c("MDI+_ridge", "MDI+_logistic_logloss", "MDA", "MDI", "TreeSHAP")
method_labels <- c("MDI+ (ridge)", "MDI+ (logistic)", "MDA", "MDI", "TreeSHAP")
alpha_values <- c(1, 1, rep(0.4, length(method_labels) - 2))

fpath <- "mdi_plus.real_data_case_study.tcga_brca_classification-"
fname <- file.path(
  results_dir, 
  fpath,
  sprintf("seed%s", seed),
  "results.csv"
)
X <- data.table::fread("/global/scratch/users/tiffanytang/feature_importance/data/X_tcga_cleaned.csv")
results <- data.table::fread(fname)
out <- plot_top_stability(
  results,
  group_id = NULL,
  top_r = 10,
  show_max_features = 20,
  varnames = colnames(X),
  base_method = "MDI+ (ridge)", 
  # descending_methods = "MDI+ (logistic)",
  return_df = TRUE,
  manual_color_palette = manual_color_palette,
  show_methods = show_methods,
  method_labels = method_labels
)

ranking_df <- out$rankings
stability_df <- out$stability
plt_ls <- out$plot_ls

cat("\n\n## Non-zero Stability Scores Per Method\n\n")
vthemes::subchunkify(
  plt_ls[[1]], i = sprintf("%s-%s", fpath, 1), 
  fig_height = fig_height * round(length(unique(results$fi)) / 2), fig_width = fig_width
)

cat("\n\n## Stability of Top Features\n\n")
vthemes::subchunkify(
  plot_fun(plt_ls[[2]]), i = sprintf("%s-%s", fpath, 2), 
  fig_height = fig_height, fig_width = fig_width
)

# get top 25 genes
top_genes <- ranking_df %>%
  dplyr::group_by(fi, var) %>%
  dplyr::summarise(mean_rank = mean(rank)) %>%
  dplyr::ungroup() %>%
  dplyr::arrange(mean_rank) %>%
  dplyr::group_by(fi) %>%
  dplyr::mutate(rank = 1:dplyr::n()) %>%
  dplyr::ungroup()

top_genes_table <- top_genes %>%
  dplyr::filter(rank <= 25) %>%
  dplyr::left_join(
    y = data.frame(var = 0:(ncol(X) - 1), 
                   gene = colnames(X)),
    by = "var"
  ) %>%
  dplyr::mutate(
    value = sprintf("%s (%s)", gene, round(mean_rank, 2))
  ) %>%
  tidyr::pivot_wider(
    id_cols = rank, 
    names_from = fi, 
    values_from = value
  ) %>%
  dplyr::rename(
    "Rank" = "rank"
  )

# write.csv(
#   top_genes_table, file.path(tables_dir, "tcga_top_genes_table.csv"),
#   row.names = FALSE, quote = FALSE
# )

stable_genes_kable <- stability_df %>%
  dplyr::group_by(fi) %>%
  dplyr::summarise(stability_1 = sum(stability_score == 1)) %>%
  tibble::as_tibble()

cat("\n\n## Table of Top Genes \n\n")
if (params$for_paper) {
  top_genes_kable <- top_genes_table %>%
    vthemes::pretty_kable(format = "latex")#, longtable = TRUE) %>%
    # kableExtra::kable_styling(latex_options = "repeat_header")
} else {
  top_genes_kable <- top_genes_table %>%
    vthemes::pretty_kable() %>%
    kableExtra::kable_styling(latex_options = "repeat_header")
  vthemes::subchunkify(
    top_genes_kable, i = sprintf("%s-table", fpath)
  )
}

cat("\n\n## Prediction using Top Features\n\n")
pred_fname <- file.path(
  results_dir, 
  fpath,
  sprintf("seed%s", seed),
  "pred_results.csv"
)
pred_results <- data.table::fread(pred_fname) %>%
  tibble::as_tibble() %>%
  dplyr::mutate(method = fi) %>%
  tidyr::pivot_wider(names_from = metric, values_from = metric_value)
plt <- pred_results %>%
  dplyr::filter(k <= 25) %>%
  plot_metrics(
    metric = c("rocauc", "accuracy"),
    x_str = "k",
    facet_str = NULL,
    point_size = point_size,
    line_size = line_size,
    errbar_width = errbar_width,
    manual_color_palette = manual_color_palette,
    show_methods = show_methods,
    method_labels = method_labels,
    alpha_values = alpha_values,
    custom_theme = custom_theme,
    inside_legend = FALSE
  ) +
  ggplot2::labs(x = "Number of Top Features (k)", y = "Mean Prediction Performance", 
                color = "Method", alpha = "Method")
if (params$for_paper) {
  ggplot2::ggsave(plt, filename = file.path(figures_dir, "tcga_brca_prediction.pdf"),
                  width = fig_width * 1.3, height = fig_height * .85)
} else {
  vthemes::subchunkify(
    plot_fun(plt), i = sprintf("%s-pred", fpath), 
    fig_height = fig_height * .8, fig_width = fig_width * 1.5
  )
}
```

```{r eval = TRUE, results = "asis"}
if (params$for_paper) {
  manual_color_palette <- c("black", "#9B5DFF", "orange", "#71beb7", "#218a1e", "#cc3399")
  show_methods <- c("MDI+_ridge", "MDI+_logistic_logloss", "MDA", "MDI", "MDI-oob", "TreeSHAP")
  method_labels <- c("MDI+ (ridge)", "MDI+ (logistic)", "MDA", "MDI", "MDI-oob", "TreeSHAP")
  
  # ccle rnaseq
  fpath <- "mdi_plus.real_data_case_study.ccle_rnaseq_regression-"
  fname <- file.path(
    results_dir, 
    fpath,
    sprintf("seed%s", seed),
    "results.csv"
  )
  X_ccle <- data.table::fread("/global/scratch/users/tiffanytang/feature_importance/data/X_ccle_rnaseq_cleaned_filtered5000.csv")
  results_ccle <- data.table::fread(fname) %>%
    dplyr::mutate(
      fi = ifelse(fi == "MDI+", "MDI+_ridge", fi)
    )
  
  # tcga
  fpath <- "mdi_plus.real_data_case_study.tcga_brca_classification-"
  fname <- file.path(
    results_dir, 
    fpath,
    sprintf("seed%s", seed),
    "results.csv"
  )
  X_tcga <- data.table::fread("/global/scratch/users/tiffanytang/feature_importance/data/X_tcga_cleaned.csv")
  results_tcga <- data.table::fread(fname)
  
  # join results
  keep_cols <- c("rep", "y_task", "model", "fi", "index", "var", "importance")
  results <- dplyr::bind_rows(
    results_ccle %>% dplyr::select(tidyselect::all_of(keep_cols)),
    results_tcga %>% dplyr::mutate(y_task = "TCGA-BRCA") %>% dplyr::select(tidyselect::all_of(keep_cols))
  ) %>%
    tibble::as_tibble()
  plt_ls <- plot_top_stability(
    results,
    group_id = "y_task",
    top_r = 10,
    show_max_features = 5,
    base_method = "MDI+ (ridge)",
    return_df = FALSE,
    manual_color_palette = rev(manual_color_palette),
    show_methods = rev(show_methods),
    method_labels = rev(method_labels)
  )
  
  plt <- plt_ls$Summary +
    ggplot2::labs(
      y = "Drug",
      x = "Number of Distinct Features in Top 10 Across 32 Training-Test Splits"
    ) +
    ggplot2::scale_x_continuous(
      n.breaks = 6
    )
  ggplot2::ggsave(plt, filename = file.path(figures_dir, "casestudy_stability_top10.pdf"), 
                        width = fig_width * 1.5, height = fig_height * 1.2)
  
  keep_methods <- c("MDI+_ridge", "MDI+_logistic_logloss", "MDA", "MDI", "TreeSHAP")
  manual_color_palette_small <- manual_color_palette[show_methods %in% keep_methods]
  show_methods_small <- show_methods[show_methods %in% keep_methods]
  method_labels_small <- method_labels[show_methods %in% keep_methods]
  
  plt_tcga <- plot_top_stability(
    results_tcga,
    top_r = 10,
    show_max_features = 5,
    base_method = "MDI+ (ridge)",
    return_df = FALSE,
    manual_color_palette = manual_color_palette_small,
    show_methods = show_methods_small,
    method_labels = method_labels_small
  )$`Stability of Top Features` +
    ggplot2::labs(title = "TCGA-BRCA") + 
    ggplot2::theme(axis.title.y = ggplot2::element_blank())
  
  plt_ls <- plot_top_stability(
    results_ccle,
    group_id = "y_task",
    top_r = 10,
    show_max_features = 5,
    # varnames = colnames(X),
    base_method = "MDI+ (ridge)",
    return_df = FALSE,
    manual_color_palette = manual_color_palette_small[show_methods_small != "MDI+_logistic_logloss"],
    show_methods = show_methods_small[show_methods_small != "MDI+_logistic_logloss"],
    method_labels = method_labels_small[show_methods_small != "MDI+_logistic_logloss"]
  )
  
  # keep_drugs <- c("Panobinostat", "Lapatinib", "L-685458")
  keep_drugs <- c("PD-0325901", "Panobinostat", "L-685458")
  small_plt_ls <- list()
  for (drug in keep_drugs) {
    plt <- plt_ls[[drug]]
    if (drug != keep_drugs[1]) {
      plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
    }
    small_plt_ls[[drug]] <- plt +
      ggplot2::guides(fill = "none")
  }
  small_plt_ls[["TCGA-BRCA"]] <- plt_tcga
  # small_plt_ls[[1]] + small_plt_ls[[2]] + small_plt_ls[[3]] + small_plt_ls[[4]] +
  #   patchwork::plot_layout(nrow = 1, ncol = 4, guides = "collect")
  plt <- patchwork::wrap_plots(small_plt_ls, nrow = 1, ncol = 4, guides = "collect") &
    ggplot2::theme(
      plot.title = ggplot2::element_text(hjust = 0.5),
      panel.background = ggplot2::element_rect(fill = "white"),
      panel.grid.major = ggplot2::element_line(color = "grey95", size = ggplot2::rel(0.25))
    )
  ggplot2::ggsave(plt, filename = file.path(figures_dir, "casestudy_stability_select_features.pdf"),
                  width = fig_width * 1, height = fig_height * 0.5)
  
  plt <- small_plt_ls[[1]] + small_plt_ls[[2]] + small_plt_ls[[3]] + patchwork::plot_spacer() + small_plt_ls[[4]] +
    patchwork::plot_layout(nrow = 1, widths = c(1, 1, 1, 0.1, 1), guides = "collect")
  # plt
  # grid::grid.draw(
  #   grid::linesGrob(x = grid::unit(c(0.68, 0.68), "npc"),
  #                   y = grid::unit(c(0.06, 0.98), "npc"))
  # )
  # # ggplot2::ggsave(
  # #   file.path(figures_dir, "casestudy_stability_select_features.pdf"),
  # #   units = "in", width = 14, height = 4
  # # )
}

```


# Misspecified Models {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
keep_methods <- color_df$name %in% c("GMDI_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
manual_color_palette <- color_df$color[keep_methods]
show_methods <- color_df$name[keep_methods]
method_labels <- color_df$label[keep_methods]
alpha_values <- c(1, rep(0.4, length(method_labels) - 1))
legend_position <- c(0.73, 0.35)

vary_param_name <- "heritability_sample_row_n"
y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")

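# Paper-figure layout settings: x-axis titles are dropped for the x_models listed in
# remove_x_axis_models, and the in-panel legend is kept only in the last heritability
# panel of the keep_legend_x_models / keep_legend_y_models combinations.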
remove_x_axis_models <- c("enhancer", "ccle_rnaseq", "juvenile")
keep_legend_x_models <- c("enhancer")
keep_legend_y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")

for (y_model in y_models) {
  cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n### %s \n\n", x_model))
    plt_ls <- list()
    sim_name <- sprintf("%s_%s_dgp_omitted_vars", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    metric_name <- dplyr::case_when(
      metric == "rocauc" ~ "AUROC",
      metric == "prauc" ~ "PRAUC"
    )
    fname <- file.path(results_dir, paste0("mdi_plus.regression_sims.", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results()
      saveRDS(results, sprintf("%s.rds", fname))
    }
    if (params$for_paper) {
      for (h in heritabilities) {
        plt <- results %>%
          dplyr::filter(heritability == !!h) %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_row_n",
            facet_str = NULL,
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            legend_position = legend_position,
            custom_theme = custom_theme,
            inside_legend = TRUE
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name, 
            color = "Method", alpha = "Method"
          )
        if (h != heritabilities[1]) {
          plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
        }
        if (x_model %in% remove_x_axis_models) {
          height <- 2.72
          plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
        } else {
          height <- 3
        }
        if (!((h == heritabilities[length(heritabilities)]) & 
              (x_model %in% keep_legend_x_models) & 
              (y_model %in% keep_legend_y_models))) {
          plt <- plt + ggplot2::guides(color = "none", alpha = "none")
        }
        plt_ls[[as.character(h)]] <- plt
      }
      plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
      ggplot2::ggsave(
        file.path(figures_dir, "misspecified_regression_sims", 
                  sprintf("misspecified_regression_sims_%s_%s_%s_errbars.pdf",
                          y_model, x_model, metric)),
        plot = plt, units = "in", width = 14, height = height
      )
    } else {
      plt <- results %>%
        plot_metrics(
          metric = metric,
          x_str = "sample_row_n",
          facet_str = "heritability_name",
          point_size = point_size,
          line_size = line_size,
          errbar_width = errbar_width,
          manual_color_palette = manual_color_palette,
          show_methods = show_methods,
          method_labels = method_labels,
          alpha_values = alpha_values,
          custom_theme = custom_theme
        ) +
        ggplot2::labs(
          x = "Sample Size", y = metric_name,
          color = "Method", alpha = "Method",
          title = sprintf("%s", sim_title)
        )
      vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                           fig_height = fig_height, fig_width = fig_width * 1.3)
      chunk_idx <- chunk_idx + 1
    }
  }
}

```


# Varying Sparsity {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
keep_methods <- color_df$name %in% c("GMDI_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
manual_color_palette <- color_df$color[keep_methods]
show_methods <- color_df$name[keep_methods]
method_labels <- color_df$label[keep_methods]
alpha_values <- c(1, rep(0.4, length(method_labels) - 1))
legend_position <- c(0.73, 0.7)

y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
x_models <- c("juvenile", "splicing")

remove_x_axis_models <- c("splicing")
keep_legend_x_models <- c("splicing")
keep_legend_y_models <- c("linear", "linear_lss_3m_2r")

for (y_model in y_models) {
  cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n### %s \n\n", x_model))
    plt_ls <- list()
    sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    metric_name <- dplyr::case_when(
      metric == "rocauc" ~ "AUROC",
      metric == "prauc" ~ "PRAUC"
    )
    x_str <- ifelse(y_model == "linear", "s", "m")
    vary_param_name <- sprintf("heritability_%s", x_str)
    fname <- file.path(results_dir, paste0("mdi_plus.other_regression_sims.varying_sparsity.", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results()
      saveRDS(results, sprintf("%s.rds", fname))
    }
    if (params$for_paper) {
      for (h in heritabilities) {
        plt <- results %>%
          dplyr::filter(heritability == !!h) %>%
          plot_metrics(
            metric = metric,
            x_str = x_str,
            facet_str = NULL,
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            legend_position = legend_position,
            custom_theme = custom_theme,
            inside_legend = TRUE
          ) +
          ggplot2::labs(
            x = toupper(x_str), y = metric_name, 
            color = "Method", alpha = "Method"
          )
        if (h != heritabilities[1]) {
          plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
        }
        if (x_model %in% remove_x_axis_models) {
          height <- 2.72
          plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
        } else {
          height <- 3
        }
        if (!((h == heritabilities[length(heritabilities)]) & 
              (x_model %in% keep_legend_x_models) & 
              (y_model %in% keep_legend_y_models))) {
          plt <- plt + ggplot2::guides(color = "none", alpha = "none")
        }
        plt_ls[[as.character(h)]] <- plt
      }
      plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
      ggplot2::ggsave(
        file.path(figures_dir, "varying_sparsity", 
                  sprintf("regression_sims_sparsity_%s_%s_%s_errbars.pdf",
                          y_model, x_model, metric)),
        plot = plt, units = "in", width = 14, height = 3
      )
    } else {
      plt <- results %>%
        plot_metrics(
          metric = metric,
          x_str = x_str,
          facet_str = "heritability_name",
          point_size = point_size,
          line_size = line_size,
          errbar_width = errbar_width,
          manual_color_palette = manual_color_palette,
          show_methods = show_methods,
          method_labels = method_labels,
          alpha_values = alpha_values,
          custom_theme = custom_theme
        ) +
        ggplot2::labs(
          x = toupper(x_str), y = metric_name,
          color = "Method", alpha = "Method",
          title = sprintf("%s", sim_title)
        )
      vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                           fig_height = fig_height, fig_width = fig_width * 1.3)
      chunk_idx <- chunk_idx + 1
    }
  }
}

```


# Varying \# Features {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
keep_methods <- color_df$name %in% c("GMDI_ridge_RF", "MDA_RF", "MDI_RF", "MDI-oob_RF", "TreeSHAP_RF")
manual_color_palette <- color_df$color[keep_methods]
show_methods <- color_df$name[keep_methods]
method_labels <- color_df$label[keep_methods]
alpha_values <- c(1, rep(0.4, length(method_labels) - 1))
legend_position <- c(0.73, 0.4)

vary_param_name <- "heritability_sample_col_n"
y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
x_models <- c("ccle_rnaseq")

remove_x_axis_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r")
keep_legend_x_models <- c("ccle_rnaseq")
keep_legend_y_models <- c("linear")

for (y_model in y_models) {
  cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n### %s \n\n", x_model))
    plt_ls <- list()
    sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    metric_name <- dplyr::case_when(
      metric == "rocauc" ~ "AUROC",
      metric == "prauc" ~ "PRAUC"
    )
    fname <- file.path(results_dir, paste0("mdi_plus.other_regression_sims.varying_p.", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results()
      saveRDS(results, sprintf("%s.rds", fname))
    }
    if (params$for_paper) {
      for (h in heritabilities) {
        plt <- results %>%
          dplyr::filter(heritability == !!h) %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_col_n",
            facet_str = NULL,
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            legend_position = legend_position,
            custom_theme = custom_theme,
            inside_legend = TRUE
          ) +
          ggplot2::labs(
            x = "Number of Features", y = metric_name, 
            color = "Method", alpha = "Method"
          )
        if (h != heritabilities[1]) {
          plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
        }
        if (y_model %in% remove_x_axis_models) {
          height <- 2.72
          plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
        } else {
          height <- 3
        }
        if (!((h == heritabilities[length(heritabilities)]) & 
              (x_model %in% keep_legend_x_models) & 
              (y_model %in% keep_legend_y_models))) {
          plt <- plt + ggplot2::guides(color = "none", alpha = "none")
        }
        plt_ls[[as.character(h)]] <- plt
      }
      plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
      ggplot2::ggsave(
        file.path(figures_dir, "varying_p", 
                  sprintf("regression_sims_vary_p_%s_%s_%s_errbars.pdf",
                          y_model, x_model, metric)),
        plot = plt, units = "in", width = 14, height = height
      )
    } else {
      plt <- results %>%
        plot_metrics(
          metric = metric,
          x_str = "sample_col_n",
          facet_str = "heritability_name",
          point_size = point_size,
          line_size = line_size,
          errbar_width = errbar_width,
          manual_color_palette = manual_color_palette,
          show_methods = show_methods,
          method_labels = method_labels,
          alpha_values = alpha_values,
          custom_theme = custom_theme
        ) +
        ggplot2::labs(
          x = "Number of Features", y = metric_name,
          color = "Method", alpha = "Method",
          title = sprintf("%s", sim_title)
        )
      vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                           fig_height = fig_height, fig_width = fig_width * 1.3)
      chunk_idx <- chunk_idx + 1
    }
  }
}

```


# Prediction Results {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
fpaths <- c(
  "mdi_plus.prediction_sims.ccle_rnaseq_regression-",
  "mdi_plus.prediction_sims.enhancer_classification-",
  "mdi_plus.prediction_sims.splicing_classification-",
  "mdi_plus.prediction_sims.juvenile_classification-",
  "mdi_plus.prediction_sims.tcga_brca_classification-"
)

keep_models <- c("RF", "RF-ridge", "RF-lasso", "RF-logistic")
prediction_metrics <- c(
  "r2", "explained_variance", "mean_squared_error", "mean_absolute_error",
  "rocauc", "prauc", "accuracy", "f1", "recall", "precision", "avg_precision", "logloss"
)

results_ls <- list()
for (fpath in fpaths) {
  fname <- file.path(
    results_dir, 
    fpath,
    sprintf("seed%s", seed),
    "results.csv"
  )
  results <- data.table::fread(fname) %>%
    reformat_results(prediction = TRUE)
  
  if (!("y_task" %in% colnames(results))) {
    results <- results %>%
      dplyr::mutate(y_task = "Results")
  }
  
  plt_df <- results %>%
    tidyr::pivot_longer(
      cols = tidyselect::any_of(prediction_metrics), 
      names_to = "metric", values_to = "value"
    ) %>% 
    dplyr::group_by(y_task, model, metric) %>%
    dplyr::summarise(
      mean = mean(value),
      sd = sd(value),
      n = dplyr::n()
    ) %>%
    dplyr::mutate(
      se = sd / sqrt(n)
    )
  
  if (fpath == "mdi_plus.prediction_sims.ccle_rnaseq_regression") {
    tab <- plt_df %>%
      dplyr::mutate(
        value = sprintf("%.3f (%.3f)", mean, se)
      ) %>%
      tidyr::pivot_wider(id_cols = model, names_from = "metric", values_from = "value") %>%
      dplyr::arrange(dplyr::desc(accuracy)) %>%
      dplyr::select(model, accuracy, f1, rocauc, prauc, precision, recall) %>%
      vthemes::pretty_kable(format = "latex")
  }
  
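  # Compare RF against its RF+ counterpart: RF+ (logistic) for the classification
  # datasets, and RF+ (ridge) per drug for the CCLE regression data, recording absolute
  # and percent changes for each metric.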
  if (fpath != "mdi_plus.prediction_sims.ccle_rnaseq_regression-") {
    results_ls[[fpath]] <- plt_df %>%
      dplyr::filter(model %in% c("RF", "RF-logistic")) %>%
      dplyr::select(model, metric, mean) %>%
      tidyr::pivot_wider(id_cols = metric, names_from = "model", values_from = "mean") %>%
      dplyr::mutate(diff = `RF-logistic` - RF,
                    percent_diff = (`RF-logistic` - RF) / abs(RF) * 100)
  } else {
    results_ls[[fpath]] <- plt_df %>%
      dplyr::filter(model %in% c("RF", "RF-ridge")) %>%
      dplyr::select(y_task, model, metric, mean) %>%
      tidyr::pivot_wider(id_cols = c(y_task, metric), names_from = "model", values_from = "mean") %>%
      dplyr::mutate(diff = `RF-ridge` - RF,
                    percent_diff = (`RF-ridge` - RF) / abs(RF) * 100)
  }
}

plt_reg <- results_ls$`mdi_plus.prediction_sims.ccle_rnaseq_regression-` %>%
  dplyr::filter(metric == "r2") %>%
  dplyr::filter(RF > 0.1) %>%
  dplyr::mutate(
    metric = ifelse(metric == "r2", "R-squared", metric)
  ) %>%
  ggplot2::ggplot() +
  ggplot2::aes(x = RF, y = percent_diff, label = y_task) +
  ggplot2::labs(x = "RF Test R-squared", y = "% Change Using RF+ (ridge)") +
  ggplot2::geom_point(size = 3) +
  ggplot2::geom_hline(yintercept = 0, linetype = "dashed") +
  ggrepel::geom_label_repel(fill = "white") +
  ggplot2::facet_wrap(~ metric, scales = "free") +
  custom_theme

plt_reg_all <- results_ls$`mdi_plus.prediction_sims.ccle_rnaseq_regression-` %>%
  dplyr::filter(metric == "r2") %>%
  dplyr::mutate(
    metric = ifelse(metric == "r2", "R-squared", metric)
  ) %>%
  ggplot2::ggplot() +
  ggplot2::aes(x = RF, y = percent_diff, label = y_task) +
  ggplot2::labs(x = "RF Test R-squared", y = "% Change Using RF+ (ridge)") +
  ggplot2::geom_point(size = 3) +
  ggplot2::geom_hline(yintercept = 0, linetype = "dashed") +
  ggrepel::geom_label_repel(fill = "white") +
  ggplot2::facet_wrap(~ metric, scales = "free") +
  custom_theme

if (params$for_paper) {
  ggplot2::ggsave(
    plot = plt_reg_all, 
    file.path(figures_dir, "prediction_results_appendix.pdf"),
    units = "in", width = 8, height = fig_height * 0.75
  )
}

plt_ls <- list(reg = plt_reg)
for (m in c("f1", "prauc")) {
  plt_df <- dplyr::bind_rows(results_ls, .id = "dataset") %>%
    dplyr::filter(metric == m) %>%
    dplyr::mutate(
      dataset = stringr::str_remove(dataset, "^mdi_plus\\.prediction_sims\\.") %>%
        stringr::str_remove("_classification-$") %>%
        stringr::str_to_title() %>%
        stringr::str_replace("Tcga_brca", "TCGA BRCA"),
      diff = ifelse(metric == "logloss", -diff, diff),
      percent_diff = ifelse(metric == "logloss", -percent_diff, percent_diff),
      metric = forcats::fct_recode(metric,
                                   "Accuracy" = "accuracy",
                                   "F1" = "f1",
                                   "Negative Log-loss" = "logloss",
                                   "AUPRC" = "prauc",
                                   "AUROC" = "rocauc")
    )
  plt_ls[[m]] <- plt_df %>%
    ggplot2::ggplot() +
    ggplot2::aes(x = RF, y = percent_diff, label = dataset) +
    ggplot2::labs(x = sprintf("RF Test %s", plt_df$metric[1]), 
                  y = "% Change Using RF+ (logistic)") +
    ggplot2::geom_point(size = 3) +
    ggplot2::geom_hline(yintercept = 0, linetype = "dashed") +
    ggrepel::geom_label_repel(point.padding = 0.5, fill = "white") +
    ggplot2::facet_wrap(~ metric, scales = "free", nrow = 1) +
    custom_theme
}

plt_ls[[3]] <- plt_ls[[3]] +
  ggplot2::theme(axis.title.y = ggplot2::element_blank())

plt <- plt_ls[[1]] + patchwork::plot_spacer() + plt_ls[[2]] + plt_ls[[3]] +
  patchwork::plot_layout(nrow = 1, widths = c(1, 0.3, 1, 1))

# plt
# grid::grid.draw(
#   grid::linesGrob(x = grid::unit(c(0.36, 0.36), "npc"),
#                   y = grid::unit(c(0.06, 0.98), "npc"))
# )
# # ggplot2::ggsave(
# #   file.path(figures_dir, "prediction_results_main.pdf"),
# #   units = "in", width = 14, height = fig_height * 0.75
# # )

vthemes::subchunkify(plt, i = chunk_idx,
                     fig_height = fig_height, fig_width = fig_width * 1.3)
chunk_idx <- chunk_idx + 1
```


# MDI+ Modeling Choices {.tabset .tabset-vmodern}

```{r eval = TRUE, results = "asis"}
manual_color_palette <- c(manual_color_palette_choices[method_labels_choices == "MDI"],
                          manual_color_palette_choices[method_labels_choices == "MDI-oob"],
                          "#AF4D98", "#FFD92F", "#FC8D62", "black")
show_methods <- c("MDI_RF", "MDI-oob_RF", "GMDI_raw_RF", "GMDI_loo_RF", "GMDI_ols_raw_loo_RF", "GMDI_ridge_raw_loo_RF")
method_labels <- c("MDI", "MDI-oob", "MDI+ (raw only)", "MDI+ (loo only)", "MDI+ (raw+loo only)", "MDI+ (ridge+raw+loo)")
alpha_values <- rep(1, length(method_labels))
legend_position <- c(0.63, 0.4)

vary_param_name <- "heritability_sample_row_n"
y_models <- c("linear", "hier_poly_3m_2r")
x_models <- c("enhancer", "ccle_rnaseq")#, "splicing")
min_samples_per_leafs <- c(5, 1)

remove_x_axis_models <- c("enhancer")
keep_legend_x_models <- c("enhancer")
keep_legend_y_models <- c("linear")

for (y_model in y_models) {
  cat(sprintf("\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model))
    plt_ls <- list()
    for (min_samples_per_leaf in min_samples_per_leafs) {
      cat(sprintf("\n\n#### min_samples_per_leaf = %s \n\n", min_samples_per_leaf))
      sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
      sim_title <- dplyr::case_when(
        x_model == "ccle_rnaseq" ~ "CCLE",
        x_model == "splicing" ~ "Splicing",
        x_model == "enhancer" ~ "Enhancer",
        x_model == "juvenile" ~ "Juvenile"
      )
      metric_name <- dplyr::case_when(
        metric == "rocauc" ~ "AUROC",
        metric == "prauc" ~ "PRAUC"
      )
      fname <- file.path(results_dir, 
                         sprintf("mdi_plus.other_regression_sims.modeling_choices_min_samples%s.%s", min_samples_per_leaf, sim_name),
                         paste0("varying_", vary_param_name), 
                         paste0("seed", seed), "results")
      if (!file.exists(sprintf("%s.csv", fname))) {
        next
      }
      if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
        results <- readRDS(sprintf("%s.rds", fname))
        if (length(setdiff(show_methods, unique(results$method))) > 0) {
          results <- data.table::fread(sprintf("%s.csv", fname)) %>%
            reformat_results()
          saveRDS(results, sprintf("%s.rds", fname))
        }
      } else {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
      if (params$for_paper) {
        for (h in heritabilities) {
          plt <- results %>%
            dplyr::filter(heritability == !!h) %>%
            plot_metrics(
              metric = metric,
              x_str = "sample_row_n",
              facet_str = NULL,
              point_size = point_size,
              line_size = line_size,
              errbar_width = errbar_width,
              manual_color_palette = manual_color_palette,
              show_methods = show_methods,
              method_labels = method_labels,
              alpha_values = alpha_values,
              legend_position = legend_position,
              custom_theme = custom_theme,
              inside_legend = TRUE
            ) +
            ggplot2::labs(
              x = "Sample Size", y = metric_name, 
              color = "Method", alpha = "Method"
            ) +
            ggplot2::theme(
              legend.text = ggplot2::element_text(size = 12)
            )
          if (h != heritabilities[1]) {
            plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
          }
          if (x_model %in% remove_x_axis_models) {
            height <- 2.72
            plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          } else {
            height <- 3
          }
          if (!((h == heritabilities[length(heritabilities)]) & 
                (x_model %in% keep_legend_x_models) & 
                (y_model %in% keep_legend_y_models))) {
            plt <- plt + ggplot2::guides(color = "none", alpha = "none")
          }
          plt_ls[[as.character(h)]] <- plt
        }
        plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
        ggplot2::ggsave(
          file.path(figures_dir, "modeling_choices",
                    sprintf("regression_sims_choices_min_samples%s_%s_%s_%s_errbars.pdf",
                            min_samples_per_leaf, y_model, x_model, metric)),
          plot = plt, units = "in", width = 14, height = height
        )
      } else {
        plt <- results %>%
          plot_metrics(
            metric = metric,
            x_str = "sample_row_n",
            facet_str = "heritability_name",
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            custom_theme = custom_theme
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name,
            color = "Method", alpha = "Method",
            title = sprintf("%s", sim_title)
          )
        vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                             fig_height = fig_height, fig_width = fig_width * 1.3)
        chunk_idx <- chunk_idx + 1
      }
    }
  }
}

```


# MDI+ GLM/Metric Choices {.tabset .tabset-vmodern}

## Held-out Test Prediction Scores {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
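# Held-out test prediction accuracy for the GLM choices underlying MDI+:
# a plain RF versus RF+ridge and RF+lasso, plotted against sample size for
# each heritability level (R-squared by default).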
manual_color_palette <- c("#A5AE9E", "black", "brown")
show_methods <- c("RF", "RF-ridge", "RF-lasso")
method_labels <- c("RF", "RF+ridge", "RF+lasso")
alpha_values <- rep(1, length(method_labels))
legend_position <- c(0.63, 0.4)

vary_param_name <- "heritability_sample_row_n"
y_models <- c("linear", "hier_poly_3m_2r")
x_models <- c("enhancer", "ccle_rnaseq")#, "splicing")
prediction_metrics <- c(
  "r2" #, "mean_absolute_error"
  # "rocauc", "prauc", "accuracy", "f1", "recall", "precision", "avg_precision", "logloss"
)

remove_x_axis_models <- c("enhancer")
keep_legend_x_models <- c("enhancer")
keep_legend_y_models <- c("linear")

for (y_model in y_models) {
  cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n#### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model))
    plt_ls <- list()
    sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    fname <- file.path(results_dir, 
                       sprintf("mdi_plus.glm_metric_choices_sims.regression_prediction_sims.%s", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (!file.exists(sprintf("%s.csv", fname))) {
      next
    }
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results(prediction = TRUE)
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results(prediction = TRUE)
      saveRDS(results, sprintf("%s.rds", fname))
    }
    for (m in prediction_metrics) {
      metric_name <- dplyr::case_when(
        m == "r2" ~ "R-squared",
        m == "mae" ~ "Mean Absolute Error"
      )
      if (params$for_paper) {
        for (h in heritabilities) {
          plt <- results %>%
            dplyr::filter(heritability == !!h) %>%
            plot_metrics(
              metric = m,
              x_str = "sample_row_n",
              facet_str = NULL,
              point_size = point_size,
              line_size = line_size,
              errbar_width = errbar_width,
              manual_color_palette = manual_color_palette,
              show_methods = show_methods,
              method_labels = method_labels,
              alpha_values = alpha_values,
              legend_position = legend_position,
              custom_theme = custom_theme,
              inside_legend = TRUE
            ) +
            ggplot2::labs(
              x = "Sample Size", y = metric_name, 
              color = "Method", alpha = "Method"
            ) +
            ggplot2::theme(
              legend.text = ggplot2::element_text(size = 12)
            )
          if (h != heritabilities[1]) {
            plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
          }
          if (x_model %in% remove_x_axis_models) {
            height <- 2.72
            plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          } else {
            height <- 3
          }
          if (!((h == heritabilities[length(heritabilities)]) & 
                (x_model %in% keep_legend_x_models) & 
                (y_model %in% keep_legend_y_models))) {
            plt <- plt + ggplot2::guides(color = "none", alpha = "none")
          }
          plt_ls[[as.character(h)]] <- plt
        }
        plt <- patchwork::wrap_plots(plt_ls, nrow = 1)
        ggplot2::ggsave(
          file.path(figures_dir, "glm_metric_choices",
                    sprintf("regression_sims_glm_metric_choices_prediction_%s_%s_%s_errbars.pdf",
                            y_model, x_model, m)),
          plot = plt, units = "in", width = 14, height = height
        )
      } else {
        plt <- results %>%
          plot_metrics(
            metric = m,
            x_str = "sample_row_n",
            facet_str = "heritability_name",
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            custom_theme = custom_theme
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name,
            color = "Method", alpha = "Method",
            title = sprintf("%s", sim_title)
          )
        vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                             fig_height = fig_height, fig_width = fig_width * 1.3)
        chunk_idx <- chunk_idx + 1
      }
    }
  }
}

```

## Stability Scores - Regression {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
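# Feature-ranking accuracy (AUROC/PRAUC) and stability (RBO) of MDI+ under
# different GLM and scoring-metric choices (ridge or lasso paired with
# r-squared or negative MAE, plus their ensemble), compared against MDI and
# MDI-oob on the regression DGPs.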
manual_color_palette <- c("black", "black", "brown", "brown", "gray", 
                                 manual_color_palette_choices[method_labels_choices == "MDI"],
                                 manual_color_palette_choices[method_labels_choices == "MDI-oob"])
show_methods <- c("GMDI_ridge_r2_RF", "GMDI_ridge_neg_mae_RF", "GMDI_lasso_r2_RF", "GMDI_lasso_neg_mae_RF", "GMDI_ensemble_RF", "MDI_RF", "MDI-oob_RF")
method_labels <- c("MDI+ (ridge, r-squared)", "MDI+ (ridge, MAE)", "MDI+ (lasso, r-squared)", "MDI+ (lasso, MAE)", "MDI+ (ensemble)", "MDI", "MDI-oob")
manual_linetype_palette <- c(1, 2, 1, 2, 1, 1, 1)
alpha_values <- c(rep(1, length(method_labels) - 2), rep(0.4, 2))
legend_position <- c(0.63, 0.4)

vary_param_name <- "heritability_sample_row_n"
# y_models <- c("linear", "lss_3m_2r", "hier_poly_3m_2r", "linear_lss_3m_2r")
# x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")
y_models <- c("linear", "hier_poly_3m_2r")
x_models <- c("enhancer", "ccle_rnaseq")
# stability_metrics <- c("tauAP", "RBO")
stability_metrics <- c("RBO")
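# Display settings: despite its name, remove_x_axis_models holds metric names
# here, so the x-axis title is dropped for the accuracy-metric row (it sits
# above the stability row in the stacked figure); the legend is kept only on
# the last heritability panel.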

remove_x_axis_models <- metric
keep_legend_x_models <- x_models
keep_legend_y_models <- y_models
keep_legend_metrics <- metric

for (y_model in y_models) {
  cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n#### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model))
    sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    fname <- file.path(results_dir, 
                       sprintf("mdi_plus.glm_metric_choices_sims.regression_sims.%s", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (!file.exists(sprintf("%s.csv", fname))) {
      next
    }
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results()
      saveRDS(results, sprintf("%s.rds", fname))
    }
    plt_ls <- list()
    for (m in c(metric, stability_metrics)) {
      metric_name <- dplyr::case_when(
        m == "rocauc" ~ "AUROC",
        m == "prauc" ~ "PRAUC",
        TRUE ~ m
      )
      if (params$for_paper) {
        for (h in heritabilities) {
          plt <- results %>%
            dplyr::filter(heritability == !!h) %>%
            plot_metrics(
              metric = m,
              x_str = "sample_row_n",
              facet_str = NULL,
              linetype_str = "method",
              point_size = point_size,
              line_size = line_size,
              errbar_width = errbar_width,
              manual_color_palette = manual_color_palette,
              show_methods = show_methods,
              method_labels = method_labels,
              alpha_values = alpha_values,
              legend_position = legend_position,
              custom_theme = custom_theme,
              inside_legend = FALSE
            ) +
            ggplot2::labs(
              x = "Sample Size", y = metric_name, 
              color = "Method", alpha = "Method", linetype = "Method"
            ) +
            ggplot2::scale_linetype_manual(
              values = manual_linetype_palette, labels = method_labels
            ) +
            ggplot2::theme(
              legend.text = ggplot2::element_text(size = 12),
              legend.key.width = ggplot2::unit(1.9, "cm")
            )
          if (h != heritabilities[1]) {
            plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
          }
          if (m %in% remove_x_axis_models) {
            height <- 2.72
            plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          } else {
            height <- 3
          }
          if (!((h == heritabilities[length(heritabilities)]) & 
                (x_model %in% keep_legend_x_models) & 
                (y_model %in% keep_legend_y_models) &
                (metric %in% keep_legend_metrics))) {
            plt <- plt + ggplot2::guides(color = "none", alpha = "none", linetype = "none")
          }
          plt_ls[[sprintf("%s_%s", as.character(h), m)]] <- plt
        }
      } else {
        plt <- results %>%
          plot_metrics(
            metric = m,
            x_str = "sample_row_n",
            facet_str = "heritability_name",
            linetype_str = "method",
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            custom_theme = custom_theme
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name,
            color = "Method", alpha = "Method", linetype = "Method",
            title = sprintf("%s", sim_title)
          ) +
          ggplot2::scale_linetype_manual(
            values = manual_linetype_palette, labels = method_labels
          ) +
          ggplot2::theme(
            legend.key.width = ggplot2::unit(1.9, "cm")
          )
        vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                             fig_height = fig_height, fig_width = fig_width * 1.5)
        chunk_idx <- chunk_idx + 1
      }
    }
    if (params$for_paper) {
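      # Stack the accuracy (top) and stability (bottom) rows into one figure
      # with a single collected legend before saving to PDF.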
      nrows <- length(c(metric, stability_metrics))
      plt <- patchwork::wrap_plots(plt_ls, nrow = nrows, guides = "collect")
      ggplot2::ggsave(
        file.path(figures_dir, "glm_metric_choices",
                  sprintf("regression_sims_glm_metric_choices_stability_%s_%s_errbars.pdf",
                          y_model, x_model)),
        plot = plt, units = "in", width = 14 * 1.2, height = height * nrows
      )
    }
  }
}

```

## Stability Scores - Classification {.tabset .tabset-pills .tabset-square}

```{r eval = TRUE, results = "asis"}
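# Classification analogue of the previous chunk: MDI+ with logistic ridge
# (log-loss or AUROC) versus ridge regression with r-squared, compared against
# MDI and MDI-oob, with sample size varied within each label-corruption level.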
manual_color_palette <- c("#9B5DFF", "#9B5DFF", "black",
                                 manual_color_palette_choices[method_labels_choices == "MDI"],
                                 manual_color_palette_choices[method_labels_choices == "MDI-oob"])
show_methods <- c("GMDI_logistic_ridge_logloss_RF", "GMDI_logistic_ridge_auroc_RF", "GMDI_ridge_r2_RF", "MDI_RF", "MDI-oob_RF")
method_labels <- c("MDI+ (logistic, log-loss)", "MDI+ (logistic, AUROC)", "MDI+ (ridge, r-squared)", "MDI", "MDI-oob")
manual_linetype_palette <- c(1, 2, 1, 1, 1)
alpha_values <- c(rep(1, length(method_labels) - 2), rep(0.4, 2))
legend_position <- c(0.63, 0.4)

vary_param_name <- "frac_label_corruption_sample_row_n"
# y_models <- c("logistic", "lss_3m_2r_logistic", "hier_poly_3m_2r_logistic", "linear_lss_3m_2r_logistic")
# x_models <- c("enhancer", "ccle_rnaseq", "juvenile", "splicing")
y_models <- c("logistic", "hier_poly_3m_2r_logistic")
x_models <- c("juvenile", "splicing")
# stability_metrics <- c("tauAP", "RBO")
stability_metrics <- "RBO"

remove_x_axis_models <- metric
keep_legend_x_models <- x_models
keep_legend_y_models <- y_models
keep_legend_metrics <- metric

for (y_model in y_models) {
  cat(sprintf("\n\n### %s {.tabset .tabset-pills .tabset-square}\n\n", y_model))
  for (x_model in x_models) {
    cat(sprintf("\n\n#### %s {.tabset .tabset-pills .tabset-circle}\n\n", x_model))
    sim_name <- sprintf("%s_%s_dgp", x_model, y_model)
    sim_title <- dplyr::case_when(
      x_model == "ccle_rnaseq" ~ "CCLE",
      x_model == "splicing" ~ "Splicing",
      x_model == "enhancer" ~ "Enhancer",
      x_model == "juvenile" ~ "Juvenile"
    )
    fname <- file.path(results_dir, 
                       sprintf("mdi_plus.glm_metric_choices_sims.classification_sims.%s", sim_name),
                       paste0("varying_", vary_param_name), 
                       paste0("seed", seed), "results")
    if (!file.exists(sprintf("%s.csv", fname))) {
      next
    }
    if (file.exists(sprintf("%s.rds", fname)) & params$use_cached) {
      results <- readRDS(sprintf("%s.rds", fname))
      if (length(setdiff(show_methods, unique(results$method))) > 0) {
        results <- data.table::fread(sprintf("%s.csv", fname)) %>%
          reformat_results()
        saveRDS(results, sprintf("%s.rds", fname))
      }
    } else {
      results <- data.table::fread(sprintf("%s.csv", fname)) %>%
        reformat_results()
      saveRDS(results, sprintf("%s.rds", fname))
    }
    plt_ls <- list()
    for (m in c(metric, stability_metrics)) {
      metric_name <- dplyr::case_when(
        m == "rocauc" ~ "AUROC",
        m == "prauc" ~ "PRAUC",
        TRUE ~ m
      )
      if (params$for_paper) {
        for (h in frac_label_corruptions) {
          plt <- results %>%
            dplyr::filter(frac_label_corruption_name == !!h) %>%
            plot_metrics(
              metric = m,
              x_str = "sample_row_n",
              facet_str = NULL,
              linetype_str = "method",
              point_size = point_size,
              line_size = line_size,
              errbar_width = errbar_width,
              manual_color_palette = manual_color_palette,
              show_methods = show_methods,
              method_labels = method_labels,
              alpha_values = alpha_values,
              legend_position = legend_position,
              custom_theme = custom_theme,
              inside_legend = FALSE
            ) +
            ggplot2::labs(
              x = "Sample Size", y = metric_name, 
              color = "Method", alpha = "Method", linetype = "Method"
            ) +
            ggplot2::scale_linetype_manual(
              values = manual_linetype_palette, labels = method_labels
            ) +
            ggplot2::theme(
              legend.text = ggplot2::element_text(size = 12),
              legend.key.width = ggplot2::unit(1.9, "cm")
            )
          if (h != frac_label_corruptions[1]) {
            plt <- plt + ggplot2::theme(axis.title.y = ggplot2::element_blank())
          }
          if (m %in% remove_x_axis_models) {
            height <- 2.72
            plt <- plt + ggplot2::theme(axis.title.x = ggplot2::element_blank())
          } else {
            height <- 3
          }
          if (!((h == frac_label_corruptions[length(frac_label_corruptions)]) & 
                (x_model %in% keep_legend_x_models) & 
                (y_model %in% keep_legend_y_models) &
                (metric %in% keep_legend_metrics))) {
            plt <- plt + ggplot2::guides(color = "none", alpha = "none", linetype = "none")
          }
          plt_ls[[sprintf("%s_%s", as.character(h), m)]] <- plt
        }
      } else {
        plt <- results %>%
          plot_metrics(
            metric = m,
            x_str = "sample_row_n",
            facet_str = "frac_label_corruption_name",
            linetype_str = "method",
            point_size = point_size,
            line_size = line_size,
            errbar_width = errbar_width,
            manual_color_palette = manual_color_palette,
            show_methods = show_methods,
            method_labels = method_labels,
            alpha_values = alpha_values,
            custom_theme = custom_theme
          ) +
          ggplot2::labs(
            x = "Sample Size", y = metric_name,
            color = "Method", alpha = "Method", linetype = "Method",
            title = sprintf("%s", sim_title)
          ) +
          ggplot2::scale_linetype_manual(
            values = manual_linetype_palette, labels = method_labels
          ) +
          ggplot2::theme(
            legend.key.width = ggplot2::unit(1.9, "cm")
          )
        vthemes::subchunkify(plot_fun(plt), i = chunk_idx,
                             fig_height = fig_height, fig_width = fig_width * 1.5)
        chunk_idx <- chunk_idx + 1
      }
    }
    if (params$for_paper) {
      nrows <- length(c(metric, stability_metrics))
      plt <- patchwork::wrap_plots(plt_ls, nrow = nrows, guides = "collect")
      ggplot2::ggsave(
        file.path(figures_dir, "glm_metric_choices",
                  sprintf("classification_sims_glm_metric_choices_stability_%s_%s_errbars.pdf",
                          y_model, x_model)),
        plot = plt, units = "in", width = 14 * 1.2, height = height * nrows
      )
    }
  }
}

```

+ + + + + + + + + diff --git a/feature_importance/readme.md b/feature_importance/readme.md new file mode 100644 index 0000000..188a2f1 --- /dev/null +++ b/feature_importance/readme.md @@ -0,0 +1,148 @@ +# Feature Importance Simulations Pipeline + +This is a basic template to run extensive simulation studies for benchmarking a new feature importance method. The main python script is `01_run_importance_simulations.py` (which is largely based on `../01_fit_models.py`). To run these simulations, follow the three main steps below: + +1. Take the feature importance method of interest, and wrap it in a function that has the following structure: + - Inputs: + - `X`: A data frame of covariates/design matrix. + - `y`: Response vector. [Note that this argument is required even if it is not used by the feature importance method.] + - `fit`: A fitted estimator (e.g., a fitted RandomForestRegressor). + - Additional input arguments are allowed, but `X`, `y`, and `fit` are required at a minimum. + - Output: + - A data frame with at least the columns `var` and `importance`, containing the variable ID and the importance scores respectively. Additional columns are also permitted. + - For examples of this feature importance wrapper, see `scripts/competing_methods.py`. +2. Update configuration files (in `fi_config/`) to set the data-generating process(es), prediction models, and feature importance estimators to run in the simulation. See below for additional information and examples of these configuration files. +3. Run `01_run_importance_simulations.py` and pass in the appropriate commandline arguments. See below for additional information and examples. + - If `--create_rmd` is passed in as an argument in step 3, this will automatically generate an html document with some basic visualization summaries of the results. These results are rendered using R Markdown via `rmd/simulation_results.Rmd` and are saved in the results folder that was specified in step 3. + +Notes: + - To apply the feature importance method to real data (or any setting where the true support/signal features are unknown), one can use `02_run_importance_real_data.py` instead of `01_run_importance_simulations.py`. + - To evaluate the prediction accuracy of the model fits, see `03_run_prediction_simulations.py` and `04_run_prediction_real_data.py` for simulated and real data, respectively. + +Additional details for steps 2 and 3 are provided below. + + +## Creating the config files (step 2) + +For a starter template, see the `fi_config/test` folder. There are two necessary files: + +- `dgp.py`: Script specifying the data-generating process under study. The following variables must be provided: + - `X_DGP`: Function to generate X data. + - `X_PARAMS_DICT`: Dictionary of named arguments to pass into the `X_DGP` function. + - `Y_DGP`: Function to generate y data. + - `Y_PARAMS_DICT`: Dictionary of named arguments to pass into the `Y_DGP` function. + - `VARY_PARAM_NAME`: Name of argument (typically in `X_DGP` or `Y_DGP`) to vary across. Note that it is also possible to vary across an argument in an `ESTIMATOR` in very basic simulation setups. This can also be a vector of parameters to vary over in a grid. + - `VARY_PARAM_VALS`: Dictionary of named arguments for the `VARY_PARAM_NAME` to take on in the simulation experiment. Note that the value can be any python object, but make sure to keep the key simple for naming and plotting purposes. +- `models.py`: Script specifying the prediction methods and feature importance estimators under study. 
The following variables must be provided: + - `ESTIMATORS`: List of prediction methods to fit. Elements should be of class `ModelConfig`. + - Note that the class passed into `ModelConfig` should have a `fit` method (e.g., like sklearn models). + - Additional arguments to pass to the model class can be specified in a dictionary using the `other_params` argument in `ModelConfig()`. + - `FI_ESTIMATORS`: List of feature importance methods to fit. Elements should be of class `FIModelConfig`. + - Note that the function passed into `FIModelConfig` should take in the arguments `X`, `y`, `fit` at a minimum. For examples, see `scripts/competing_methods.py`. + - Additional arguments to pass to the feature importance function can be specified in a dictionary using the `other_params` argument in `FIModelConfig()`. + - Pair up a prediction model and feature importance estimator by using the same `model_type` ID in both. + - Note that by default, higher values from the feature importance method are assumed to indicate higher importance. If higher values indicate lower importance, set `ascending=False` in `FIModelConfig()`. + - If a feature importance estimator requires sample splitting (outside of the feature importance function call), use the `splitting_strategy` argument to specify the type of splitting strategy (e.g., `train-test`). + +For an example of the fi_config files used for real data case studies, see the `fi_config/mdi_plus/real_data_case_study/ccle_rnaseq_regression-/` folder. Like before, there are two necessary files: `dgp.py` and `models.py`. `models.py` follows the same structure and requirements as above. `dgp.py` is a script that defines two variables, `X_PATH` and `Y_PATH`, specifying the file paths of the covariate/feature data `X` and the response data `y`, respectively. + + +## Running the simulations (step 3) + +**For running simulations to evaluate feature importance rankings:** `01_run_importance_simulations.py` + +- Command Line Arguments: + - Simulation settings: + - `--nreps`: Number of replicates. + - `--model`: Name of prediction model to run. Default (`None`) uses all models specified in `models.py` config file. + - `--fi_model`: Name of feature importance estimator to run. Default (`None`) uses all feature importance estimators specified in `models.py` config file. + - `--config`: Name of fi_config folder (and title of the simulation experiment). + - `--omit_vars`: (Optional) Comma-separated string of feature indices to omit (as unobserved variables). More specifically, these features may be used in generating the response *y* but are omitted from the *X* used in training/evaluating the prediction model and feature importance estimator. + - Computational settings: + - `--nosave_cols`: (Optional) Comma-separated string of column names to omit in the output file (to avoid potential errors when saving to pickle). + - `--ignore_cache`: Whether or not to ignore cached results. + - `--verbose`: Whether or not to print messages. + - `--parallel`: Whether or not to run replicates in parallel. + - `--parallel_id`: ID for parallelization. + - `--n_cores`: Number of cores if running in parallel. Default uses all available cores. + - `--split_seed`: Seed for data splitting. + - `--results_path`: Path to save results. Default is `./results`. + - R Markdown output options: + - `--create_rmd`: Whether or not to output R Markdown-generated html file with summary of results. + - `--show_vars`: Max number of features to show in rejection probability plots in the R Markdown.
Default (`None`) is to show all variables. +- Example usage in command line: +``` +python 01_run_importance_simulations.py --nreps 100 --config test --split_seed 331 --ignore_cache --create_rmd +``` + +**For running feature importance methods on real data:** `02_run_importance_real_data.py` + +- Command Line Arguments: + - Simulation settings: + - `--nreps`: Number of replicates (or times to run method on the given data). + - `--model`: Name of prediction model to run. Default (`None`) uses all models specified in `models.py` config file. + - `--fi_model`: Name of feature importance estimator to run. Default (`None`) uses all feature importance estimators specified in `models.py` config file. + - `--config`: Name of fi_config folder (and title of the simulation experiment). + - `--response_idx`: (Optional) Name of response column to use if response data *y* is a matrix or multi-task. If not provided, independent regression/classification problems are fitted for every column of *y* separately. If *y* is not a matrix, this argument should be ignored and is unused. + - Computational settings: + - `--nosave_cols`: (Optional) Comma-separated string of column names to omit in the output file (to avoid potential errors when saving to pickle). + - `--ignore_cache`: Whether or not to ignore cached results. + - `--verbose`: Whether or not to print messages. + - `--parallel`: Whether or not to run replicates in parallel. + - `--parallel_id`: ID for parallelization. + - `--n_cores`: Number of cores if running in parallel. Default uses all available cores. + - `--split_seed`: Seed for data splitting. + - `--results_path`: Path to save results. Default is `./results`. +- Example usage in command line: +``` +python 02_run_importance_real_data.py --nreps 1 --config test --split_seed 331 --ignore_cache +``` + +**For running simulations to evaluate prediction accuracy of the model fits:** `03_run_prediction_simulations.py` + +- Command Line Arguments: + - Simulation settings: + - `--nreps`: Number of replicates. + - `--mode`: One of 'regression' or 'binary_classification'. + - `--model`: Name of prediction model to run. Default (`None`) uses all models specified in `models.py` config file. + - `--config`: Name of fi_config folder (and title of the simulation experiment). + - `--omit_vars`: (Optional) Comma-separated string of feature indices to omit (as unobserved variables). More specifically, these features may be used in generating the response *y* but are omitted from the *X* used in training/evaluating the prediction model and feature importance estimator. + - Computational settings: + - `--nosave_cols`: (Optional) Comma-separated string of column names to omit in the output file (to avoid potential errors when saving to pickle). + - `--ignore_cache`: Whether or not to ignore cached results. + - `--splitting_strategy`: One of 'train-test', 'train-tune-test', 'train-test-lowdata', 'train-tune-test-lowdata', indicating how to split the data into training and test sets. + - `--verbose`: Whether or not to print messages. + - `--parallel`: Whether or not to run replicates in parallel. + - `--parallel_id`: ID for parallelization. + - `--n_cores`: Number of cores if running in parallel. Default uses all available cores. + - `--split_seed`: Seed for data splitting. + - `--results_path`: Path to save results. Default is `./results`.
+- Example usage in command line: +``` +python 03_run_prediction_simulations.py --nreps 100 --config mdi_plus.glm_metric_choices_sims.regression_prediction_sims.enhancer_linear_dgp --mode regression --split_seed 331 --ignore_cache --nosave_cols prediction_model +``` + +**For running prediction methods on real data:** `04_run_prediction_real_data.py` + +- Command Line Arguments: + - Simulation settings: + - `--nreps`: Number of replicates. + - `--mode`: One of 'regression' or 'binary_classification'. + - `--model`: Name of prediction model to run. Default (`None`) uses all models specified in `models.py` config file. + - `--config`: Name of fi_config folder (and title of the simulation experiment). + - `--response_idx`: (Optional) Name of response column to use if response data *y* is a matrix or multi-task. If not provided, independent regression/classification problems are fitted for every column of *y* separately. If *y* is not a matrix, this argument should be ignored and is unused. + - `--subsample_n`: (Optional) Integer indicating max number of samples to use in training prediction model. If None, no subsampling occurs. + - Computational settings: + - `--nosave_cols`: (Optional) Comma-separated string of column names to omit in the output file (to avoid potential errors when saving to pickle). + - `--ignore_cache`: Whether or not to ignore cached results. + - `--splitting_strategy`: One of 'train-test', 'train-tune-test', 'train-test-lowdata', 'train-tune-test-lowdata', indicating how to split the data into training and test sets. + - `--verbose`: Whether or not to print messages. + - `--parallel`: Whether or not to run replicates in parallel. + - `--parallel_id`: ID for parallelization. + - `--n_cores`: Number of cores if running in parallel. Default uses all available cores. + - `--split_seed`: Seed for data splitting. + - `--results_path`: Path to save results. Default is `./results`.
+- Example usage in command line: +``` +python 04_run_prediction_real_data.py --nreps 10 --config mdi_plus.prediction_sims.enhancer_classification- --mode binary_classification --split_seed 331 --ignore_cache --nosave_cols prediction_model +``` \ No newline at end of file diff --git a/feature_importance/rmd/simulation_results.Rmd b/feature_importance/rmd/simulation_results.Rmd new file mode 100644 index 0000000..8aef415 --- /dev/null +++ b/feature_importance/rmd/simulation_results.Rmd @@ -0,0 +1,717 @@ +--- +title: "Simulation Results" +author: "" +date: "`r format(Sys.time(), '%B %d, %Y')`" +output: vthemes::vmodern +params: + results_dir: + label: "Results directory" + value: "results/test" + vary_param_name: + label: "Name of varying parameter" + value: "sample_row_n" + seed: + label: "Seed" + value: 0 + keep_vars: + label: "Max variables to keep in rejection probability plots" + value: 100 + rm_fi: + label: "Feature importance methods to omit" + value: NULL + abridged: + label: "Abridged Document" + value: TRUE +--- + +```{r setup, include=FALSE} +knitr::opts_chunk$set(echo = FALSE, warning = FALSE, message = FALSE) + +library(magrittr) +chunk_idx <- 1 + +# set parameters +results_dir <- params$results_dir +vary_param_name <- params$vary_param_name +if (is.null(vary_param_name)) { + vary_param_name_vec <- NULL +} else if (stringr::str_detect(vary_param_name, ";")) { + vary_param_name_vec <- stringr::str_split(vary_param_name, "; ")[[1]] + vary_param_name <- paste(vary_param_name_vec, collapse = "_") + if (length(vary_param_name_vec) != 2) { + stop("Rmarkdown report has not been configured to show results when >2 parameters are being varied.") + } +} else { + vary_param_name_vec <- NULL +} +seed <- params$seed +if (!is.null(params$keep_vars)) { + keep_vars <- 0:params$keep_vars +} else { + keep_vars <- params$keep_vars +} +abridged <- params$abridged +``` + +```{r helper-functions, echo = FALSE} +# reformat results +reformat_results <- function(results) { + if (!is.null(params$rm_fi)) { + results <- results %>% + dplyr::filter(!(fi %in% params$rm_fi)) + } + results_grouped <- results %>% + dplyr::group_by(index) %>% + tidyr::nest(fi_scores = var:(tidyselect::last_col())) %>% + dplyr::ungroup() %>% + dplyr::select(-index) %>% + # join fi+model to get method column + tidyr::unite(col = "method", fi, model, na.rm = TRUE, remove = FALSE) %>% + dplyr::mutate( + # get rid of duplicate RF in r2f method name + method = ifelse(stringr::str_detect(method, "^r2f.*RF$"), + stringr::str_remove(method, "\\_RF$"), method) + ) + return(results_grouped) +} + +# plot metrics (mean value across repetitions with error bars) +plot_metrics <- function(results, vary_param_name, vary_param_name_vec, + show_errbars = TRUE) { + if (!is.null(vary_param_name_vec)) { + vary_param_name <- vary_param_name_vec + } + plt_df <- results %>% + dplyr::select(rep, method, + tidyselect::all_of(c(paste0(vary_param_name, "_name"), + metrics))) %>% + tidyr::pivot_longer( + cols = tidyselect::all_of(metrics), names_to = "metric" + ) %>% + dplyr::group_by( + dplyr::across(tidyselect::all_of(paste0(vary_param_name, "_name"))), + method, metric + ) %>% + dplyr::summarise(mean = mean(value), + sd = sd(value) / sqrt(dplyr::n()), + .groups = "keep") + + if (is.null(vary_param_name_vec)) { + if (length(unique(plt_df[[paste0(vary_param_name, "_name")]])) == 1) { + plt <- ggplot2::ggplot(plt_df) + + ggplot2::aes(x = method, y = mean) + + ggplot2::facet_wrap(~ metric, scales = "free_y", nrow = 1, ncol = 2) + + ggplot2::geom_point() + + 
vthemes::theme_vmodern() + + vthemes::scale_color_vmodern(discrete = TRUE) + + ggplot2::labs(x = "Method") + if (show_errbars) { + plt <- plt + + ggplot2::geom_errorbar( + mapping = ggplot2::aes(x = method, ymin = mean - sd, ymax = mean + sd), + width = 0 + ) + } + } else { + plt <- ggplot2::ggplot(plt_df) + + ggplot2::aes(x = .data[[paste0(vary_param_name, "_name")]], + y = mean, color = method) + + ggplot2::facet_wrap(~ metric, scales = "free_y", nrow = 1, ncol = 2) + + ggplot2::geom_point() + + ggplot2::geom_line() + + vthemes::theme_vmodern() + + vthemes::scale_color_vmodern(discrete = TRUE) + + ggplot2::labs(x = vary_param_name) + if (show_errbars) { + plt <- plt + + ggplot2::geom_errorbar( + mapping = ggplot2::aes(x = .data[[paste0(vary_param_name, "_name")]], + ymin = mean - sd, ymax = mean + sd), + width = 0 + ) + } + } + } else { + plt <- plt_df %>% + ggplot2::ggplot() + + ggplot2::aes(x = .data[[paste0(vary_param_name[2], "_name")]], + y = mean, color = method) + + ggplot2::facet_grid(metric ~ .data[[paste0(vary_param_name[1], "_name")]]) + + ggplot2::geom_point() + + ggplot2::geom_line() + + vthemes::theme_vmodern() + + vthemes::scale_color_vmodern(discrete = TRUE) + + ggplot2::labs(x = vary_param_name[2]) + if (show_errbars) { + plt <- plt + + ggplot2::geom_errorbar( + mapping = ggplot2::aes(x = .data[[paste0(vary_param_name[2], "_name")]], + ymin = mean - sd, ymax = mean + sd), + width = 0 + ) + } + } + return(plt) +} + +# plot restricted auroc/auprc +plot_restricted_metrics <- function(results, vary_param_name, + vary_param_name_vec, + quantiles = c(.1, .2, .3, .4), + show_errbars = TRUE) { + if (!is.null(vary_param_name_vec)) { + vary_param_name <- vary_param_name_vec + } + results <- results %>% + dplyr::select(rep, method, fi_scores, + tidyselect::all_of(c(paste0(vary_param_name, "_name")))) %>% + dplyr::mutate( + vars_ordered = purrr::map( + fi_scores, + function(fi_df) { + fi_df %>% + dplyr::filter(!is.na(cor_with_signal)) %>% + dplyr::arrange(-cor_with_signal) %>% + dplyr::pull(var) + } + ) + ) + + plt_df_ls <- list() + for (q in quantiles) { + plt_df_ls[[as.character(q)]] <- results %>% + dplyr::mutate( + restricted_metrics = purrr::map2_dfr( + fi_scores, vars_ordered, + function(fi_df, ignore_vars) { + ignore_vars <- ignore_vars[1:round(q * length(ignore_vars))] + auroc_r <- fi_df %>% + dplyr::filter(!(var %in% ignore_vars)) %>% + yardstick::roc_auc( + truth = factor(true_support, levels = c("1", "0")), importance, + event_level = "first" + ) %>% + dplyr::pull(.estimate) + auprc_r <- fi_df %>% + dplyr::filter(!(var %in% ignore_vars)) %>% + yardstick::pr_auc( + truth = factor(true_support, levels = c("1", "0")), importance, + event_level = "first" + ) %>% + dplyr::pull(.estimate) + return(data.frame(restricted_auroc = auroc_r, + restricted_auprc = auprc_r)) + } + ) + ) %>% + tidyr::unnest(restricted_metrics) %>% + tidyr::pivot_longer( + cols = c(restricted_auroc, restricted_auprc), names_to = "metric" + ) %>% + dplyr::group_by( + dplyr::across(tidyselect::all_of(paste0(vary_param_name, "_name"))), + method, metric + ) %>% + dplyr::summarise(mean = mean(value), + sd = sd(value) / sqrt(dplyr::n()), + .groups = "keep") %>% + dplyr::ungroup() + } + plt_df <- purrr::map_dfr(plt_df_ls, ~.x, .id = ".threshold") %>% + dplyr::mutate(.threshold = as.numeric(.threshold)) + + if (is.null(vary_param_name_vec)) { + if (length(unique(plt_df[[paste0(vary_param_name, "_name")]])) == 1) { + plt <- ggplot2::ggplot(plt_df) + + ggplot2::aes(x = method, y = mean) + + 
ggplot2::facet_grid(metric ~ .threshold, scales = "free_y") + + ggplot2::geom_point() + + vthemes::theme_vmodern() + + vthemes::scale_color_vmodern(discrete = TRUE) + + ggplot2::labs(x = "Method") + if (show_errbars) { + plt <- plt + + ggplot2::geom_errorbar( + mapping = ggplot2::aes(x = method, ymin = mean - sd, ymax = mean + sd), + width = 0 + ) + } + } else { + plt <- ggplot2::ggplot(plt_df) + + ggplot2::aes(x = .data[[paste0(vary_param_name, "_name")]], + y = mean, color = method) + + ggplot2::facet_grid(metric ~ .threshold, scales = "free_y") + + ggplot2::geom_point() + + ggplot2::geom_line() + + vthemes::theme_vmodern() + + vthemes::scale_color_vmodern(discrete = TRUE) + + ggplot2::labs(x = vary_param_name) + if (show_errbars) { + plt <- plt + + ggplot2::geom_errorbar( + mapping = ggplot2::aes(x = .data[[paste0(vary_param_name, "_name")]], + ymin = mean - sd, ymax = mean + sd), + width = 0 + ) + } + } + } else { + plt <- plt_df %>% + ggplot2::ggplot() + + ggplot2::aes(x = .data[[paste0(vary_param_name[2], "_name")]], + y = mean, color = method) + + ggplot2::facet_grid(metric + .threshold ~ .data[[paste0(vary_param_name[1], "_name")]]) + + ggplot2::geom_point() + + ggplot2::geom_line() + + vthemes::theme_vmodern() + + vthemes::scale_color_vmodern(discrete = TRUE) + + ggplot2::labs(x = vary_param_name[2]) + if (show_errbars) { + plt <- plt + + ggplot2::geom_errorbar( + mapping = ggplot2::aes(x = .data[[paste0(vary_param_name[2], "_name")]], + ymin = mean - sd, ymax = mean + sd), + width = 0 + ) + } + } + return(plt) +} + +# plot true positive rate across # positives +plot_tpr <- function(results, vary_param_name, vary_param_name_vec) { + if (!is.null(vary_param_name_vec)) { + vary_param_name <- vary_param_name_vec + } + if (is.null(results)) { + return(NULL) + } + + plt_df <- results %>% + dplyr::mutate( + fi_scores = mapply(name = fi, scores_df = fi_scores, + function(name, scores_df) { + scores_df <- scores_df %>% + dplyr::mutate( + ranking = rank(-importance, + ties.method = "random") + ) %>% + dplyr::arrange(ranking) %>% + dplyr::mutate( + .tp = cumsum(true_support) / sum(true_support) + ) + return(scores_df) + }, SIMPLIFY = FALSE) + ) %>% + tidyr::unnest(fi_scores) %>% + dplyr::select(tidyselect::all_of(paste0(vary_param_name, "_name")), + rep, method, ranking, .tp) %>% + dplyr::group_by( + dplyr::across(tidyselect::all_of(paste0(vary_param_name, "_name"))), + method, ranking + ) %>% + dplyr::summarise(.tp = mean(.tp), .groups = "keep") + + if (is.null(vary_param_name_vec)) { + plt <- ggplot2::ggplot(plt_df) + + ggplot2::aes(x = ranking, y = .tp, color = method) + + ggplot2::geom_line() + + ggplot2::facet_wrap(reformulate(paste0(vary_param_name, "_name"))) + + ggplot2::labs(x = "Top n", y = "True Positive Rate", fill = "Method") + + vthemes::scale_color_vmodern(discrete = TRUE) + + vthemes::theme_vmodern() + } else { + plt <- ggplot2::ggplot(plt_df) + + ggplot2::aes(x = ranking, y = .tp, color = method) + + ggplot2::geom_line() + + ggplot2::facet_grid( + reformulate(paste0(vary_param_name[1], "_name"), + paste0(vary_param_name[2], "_name")) + ) + + ggplot2::labs(x = "Top n", y = "True Positives Rate", fill = "Method") + + vthemes::scale_color_vmodern(discrete = TRUE) + + vthemes::theme_vmodern() + } + return(plt) +} + +# plot feature importances +plot_feature_importance <- function(results, vary_param_name, + vary_param_name_vec, + keep_vars = NULL, + plot_type = c("boxplot", "bar")) { + if (!is.null(vary_param_name_vec)) { + vary_param_name <- vary_param_name_vec + } + + 
plot_type <- match.arg(plot_type) + plt_df <- results %>% + tidyr::unnest(fi_scores) + if (plot_type == "bar") { + plt_df <- plt_df %>% + dplyr::select(tidyselect::all_of(paste0(vary_param_name, "_name")), + rep, method, var, importance) %>% + dplyr::group_by( + dplyr::across(tidyselect::all_of(paste0(vary_param_name, "_name"))), + method, var + ) %>% + dplyr::summarise(mean_fi = mean(importance), .groups = "keep") + } + if (!is.null(keep_vars)) { + plt_df <- plt_df %>% + dplyr::filter(var %in% keep_vars) + } + plt_ls <- list() + if (is.null(vary_param_name_vec)) { + for (val in unique(plt_df[[paste0(vary_param_name, "_name")]])) { + if (plot_type == "bar") { + plt <- plt_df %>% + dplyr::filter(.data[[paste0(vary_param_name, "_name")]] == !!val) %>% + ggplot2::ggplot() + + ggplot2::aes(x = var, y = mean_fi) + + ggplot2::geom_bar(stat = "identity", color = "grey98", + fill = "#00C5FF") + + ggplot2::facet_wrap(~ method, scales = "free", + ncol = 2, + nrow = ceiling(length(unique(plt_df$method)) / 2)) + + ggplot2::labs(title = sprintf("%s = %s", vary_param_name, val), + x = "Feature", y = "Mean Importance / Significance") + + vthemes::theme_vmodern() + } else if (plot_type == "boxplot") { + plt <- plt_df %>% + dplyr::filter(.data[[paste0(vary_param_name, "_name")]] == !!val) %>% + ggplot2::ggplot() + + ggplot2::aes(x = var, y = importance, group = var) + + ggplot2::geom_boxplot() + + ggplot2::facet_wrap(~ method, scales = "free", + ncol = 2, + nrow = ceiling(length(unique(plt_df$method)) / 2)) + + ggplot2::labs(title = sprintf("%s = %s", vary_param_name, val), + x = "Feature", y = "Importance / Significance") + + vthemes::theme_vmodern() + } + plt_ls[[as.character(val)]] <- plt + } + } else { + for (val1 in unique(plt_df[[paste0(vary_param_name[1], "_name")]])) { + plt_ls[[as.character(val1)]] <- list() + for (val2 in unique(plt_df[[paste0(vary_param_name[2], "_name")]])) { + if (plot_type == "bar") { + plt <- plt_df %>% + dplyr::filter( + .data[[paste0(vary_param_name[1], "_name")]] == !!val1, + .data[[paste0(vary_param_name[2], "_name")]] == !!val2 + ) %>% + ggplot2::ggplot() + + ggplot2::aes(x = var, y = mean_fi) + + ggplot2::geom_bar(stat = "identity", color = "grey98", fill = "#00C5FF") + + ggplot2::facet_wrap(~ method, scales = "free", + ncol = 2, + nrow = ceiling(length(unique(plt_df$method)) / 2)) + + ggplot2::labs(title = sprintf("%s = %s; %s = %s", + vary_param_name[1], val1, + vary_param_name[2], val2), + x = "Feature", y = "Mean Importance / Significance") + + vthemes::theme_vmodern() + } else if (plot_type == "boxplot") { + plt <- plt_df %>% + dplyr::filter( + .data[[paste0(vary_param_name[1], "_name")]] == !!val1, + .data[[paste0(vary_param_name[2], "_name")]] == !!val2 + ) %>% + ggplot2::ggplot() + + ggplot2::aes(x = var, y = importance, group = var) + + ggplot2::geom_boxplot() + + ggplot2::facet_wrap(~ method, scales = "free", + ncol = 2, + nrow = ceiling(length(unique(plt_df$method)) / 2)) + + ggplot2::labs(title = sprintf("%s = %s; %s = %s", + vary_param_name[1], val1, + vary_param_name[2], val2), + x = "Feature", y = "Importance / Significance") + + vthemes::theme_vmodern() + } + plt_ls[[as.character(val1)]][[as.character(val2)]] <- plt + } + } + } + + return(plt_ls) +} + +# plot ranking heatmap +plot_ranking_heatmap <- function(results, vary_param_name, vary_param_name_vec, + keep_vars = NULL) { + if (!is.null(vary_param_name_vec)) { + vary_param_name <- vary_param_name_vec + } + + plt_df <- results %>% + dplyr::mutate( + fi_scores = mapply(name = fi, scores_df = 
fi_scores, + function(name, scores_df) { + scores_df <- scores_df %>% + dplyr::mutate(ranking = rank(-importance)) + return(scores_df) + }, SIMPLIFY = FALSE) + ) %>% + tidyr::unnest(fi_scores) %>% + dplyr::select(tidyselect::all_of(paste0(vary_param_name, "_name")), + rep, method, var, ranking, importance) + + if (!is.null(keep_vars)) { + plt_df <- plt_df %>% + dplyr::filter(var %in% keep_vars) + } + plt_ls <- list() + if (is.null(vary_param_name_vec)) { + for (val in unique(plt_df[[paste0(vary_param_name, "_name")]])) { + plt <- plt_df %>% + dplyr::filter(.data[[paste0(vary_param_name, "_name")]] == !!val) %>% + ggplot2::ggplot() + + ggplot2::aes(x = var, y = rep, fill = ranking, text = importance) + + ggplot2::geom_tile() + + ggplot2::facet_wrap(~ method, scales = "free", + ncol = 2, + nrow = ceiling(length(unique(plt_df$method)) / 2)) + + ggplot2::coord_cartesian(expand = FALSE) + + ggplot2::labs(title = sprintf("%s = %s", vary_param_name, val), + x = "Feature", y = "Replicate", fill = "Ranking") + + vthemes::scale_fill_vmodern() + + vthemes::theme_vmodern() + plt_ls[[as.character(val)]] <- plt + } + } else { + for (val1 in unique(plt_df[[paste0(vary_param_name[1], "_name")]])) { + plt_ls[[as.character(val1)]] <- list() + for (val2 in unique(plt_df[[paste0(vary_param_name[2], "_name")]])) { + plt <- plt_df %>% + dplyr::filter( + .data[[paste0(vary_param_name[1], "_name")]] == !!val1, + .data[[paste0(vary_param_name[2], "_name")]] == !!val2 + ) %>% + ggplot2::ggplot() + + ggplot2::aes(x = var, y = rep, fill = ranking, text = importance) + + ggplot2::geom_tile() + + ggplot2::facet_wrap(~ method, scales = "free", + ncol = 2, + nrow = ceiling(length(unique(plt_df$method)) / 2)) + + ggplot2::coord_cartesian(expand = FALSE) + + ggplot2::labs(title = sprintf("%s = %s; %s = %s", + vary_param_name[1], val1, + vary_param_name[2], val2), + x = "Feature", y = "Replicate", fill = "Ranking") + + vthemes::scale_fill_vmodern() + + vthemes::theme_vmodern() + plt_ls[[as.character(val1)]][[as.character(val2)]] <- plt + } + } + } + return(plt_ls) +} + +# view results in Rmarkdown +# notes: need to set 'results = "asis"' in the code chunk header +view_results <- function(results_ls, metrics_plt_ls, rmetrics_plt_ls, + tpr_plt_ls, fi_bar_plt_ls, fi_box_plt_ls, + heatmap_plt_ls, vary_param_name_vec, abridged, + interactive = TRUE) { + cat(sprintf("\n\n# %s {.tabset .tabset-vmodern}\n\n", + basename(results_dir))) + if (is.null(vary_param_name_vec)) { + height <- 4 + tpr_height <- height + } else { + height <- 8 + tpr_height <- 4 * length(unique(results_ls[[paste0(vary_param_name_vec[2], + "_name")]])) + } + + for (sim_name in names(results_ls)) { + vary_param_name <- stringr::str_remove(sim_name, "^varying\\_") + cat(sprintf("\n\n## %s {.tabset .tabset-pills}\n\n", sim_name)) + + if (!abridged) { + cat(sprintf("\n\n### Tables\n\n")) + vthemes::subchunkify(vthemes::pretty_DT(results_ls[[sim_name]]), + i = chunk_idx) + chunk_idx <<- chunk_idx + 1 + } + + cat(sprintf("\n\n### Plots {.tabset .tabset-pills .tabset-square}\n\n")) + if (interactive) { + vthemes::subchunkify(plotly::ggplotly(metrics_plt_ls[[sim_name]]), + i = chunk_idx, other_args = "out.width = '100%'", + fig_height = height, + add_class = "panel panel-default padded-panel") + chunk_idx <<- chunk_idx + 1 + vthemes::subchunkify(plotly::ggplotly(rmetrics_plt_ls[[sim_name]]), + i = chunk_idx, other_args = "out.width = '100%'", + fig_height = height, + add_class = "panel panel-default padded-panel") + chunk_idx <<- chunk_idx + 1 + 
vthemes::subchunkify(plotly::ggplotly(tpr_plt_ls[[sim_name]]), + i = chunk_idx, other_args = "out.width = '100%'", + fig_height = height, + add_class = "panel panel-default padded-panel") + } else { + vthemes::subchunkify(metrics_plt_ls[[sim_name]], i = chunk_idx, + fig_height = height) + chunk_idx <<- chunk_idx + 1 + vthemes::subchunkify(rmetrics_plt_ls[[sim_name]], i = chunk_idx, + fig_height = height) + chunk_idx <<- chunk_idx + 1 + vthemes::subchunkify(tpr_plt_ls[[sim_name]], i = chunk_idx, + fig_height = height) + } + chunk_idx <<- chunk_idx + 1 + + if (is.null(vary_param_name_vec)) { + for (param_val in names(heatmap_plt_ls[[sim_name]])) { + cat(sprintf("\n\n#### %s = %s\n\n\n", vary_param_name, param_val)) + if (interactive) { + vthemes::subchunkify(plotly::ggplotly(heatmap_plt_ls[[sim_name]][[param_val]]), + i = chunk_idx, other_args = "out.width = '100%'", + add_class = "panel panel-default padded-panel") + chunk_idx <<- chunk_idx + 1 + vthemes::subchunkify(plotly::ggplotly(fi_box_plt_ls[[sim_name]][[param_val]]), + i = chunk_idx, other_args = "out.width = '100%'", + add_class = "panel panel-default padded-panel") + if (!abridged) { + chunk_idx <<- chunk_idx + 1 + vthemes::subchunkify(plotly::ggplotly(fi_bar_plt_ls[[sim_name]][[param_val]]), + i = chunk_idx, other_args = "out.width = '100%'", + add_class = "panel panel-default padded-panel") + } + } else { + vthemes::subchunkify(heatmap_plt_ls[[sim_name]][[param_val]], i = chunk_idx) + chunk_idx <<- chunk_idx + 1 + vthemes::subchunkify(fi_box_plt_ls[[sim_name]][[param_val]], i = chunk_idx) + if (!abridged) { + chunk_idx <<- chunk_idx + 1 + vthemes::subchunkify(fi_bar_plt_ls[[sim_name]][[param_val]], i = chunk_idx) + } + } + chunk_idx <<- chunk_idx + 1 + } + } else { + for (param_val1 in names(heatmap_plt_ls[[sim_name]])) { + cat(sprintf("\n\n#### %s = %s {.tabset .tabset-pills .tabset-circle}\n\n\n", + vary_param_name_vec[1], param_val1)) + for (param_val2 in names(heatmap_plt_ls[[sim_name]][[param_val1]])) { + cat(sprintf("\n\n##### %s = %s\n\n\n", + vary_param_name_vec[2], param_val2)) + if (interactive) { + vthemes::subchunkify(plotly::ggplotly(heatmap_plt_ls[[sim_name]][[param_val1]][[param_val2]]), + i = chunk_idx, other_args = "out.width = '100%'", + add_class = "panel panel-default padded-panel") + chunk_idx <<- chunk_idx + 1 + vthemes::subchunkify(plotly::ggplotly(fi_box_plt_ls[[sim_name]][[param_val1]][[param_val2]]), + i = chunk_idx, other_args = "out.width = '100%'", + add_class = "panel panel-default padded-panel") + if (!abridged) { + chunk_idx <<- chunk_idx + 1 + vthemes::subchunkify(plotly::ggplotly(fi_bar_plt_ls[[sim_name]][[param_val1]][[param_val2]]), + i = chunk_idx, other_args = "out.width = '100%'", + add_class = "panel panel-default padded-panel") + } + } else { + vthemes::subchunkify(heatmap_plt_ls[[sim_name]][[param_val1]][[param_val2]], i = chunk_idx) + chunk_idx <<- chunk_idx + 1 + vthemes::subchunkify(fi_box_plt_ls[[sim_name]][[param_val1]][[param_val2]], i = chunk_idx) + if (!abridged) { + chunk_idx <<- chunk_idx + 1 + vthemes::subchunkify(fi_bar_plt_ls[[sim_name]][[param_val1]][[param_val2]], i = chunk_idx) + } + } + chunk_idx <<- chunk_idx + 1 + } + } + } + } +} +``` + +```{r} +# read in results +results_ls <- list() +for (results_subdir in list.dirs(results_dir, full.names = T, recursive = F)) { + if (!is.null(vary_param_name)) { + if (!(results_subdir %in% file.path(results_dir, + paste0("varying_", vary_param_name)))) { + next + } + } + fname <- file.path(results_subdir, paste0("seed", seed), 
"results.csv") + if (file.exists(fname)) { + results_ls[[basename(results_subdir)]] <- data.table::fread(fname) %>% + reformat_results() + } +} + +# plot evaluation metrics +metrics_plt_ls <- list() +for (sim_name in names(results_ls)) { + metrics <- intersect(colnames(results_ls[[sim_name]]), c("rocauc", "prauc")) + if (length(metrics) > 0) { + vary_param_name <- stringr::str_remove(sim_name, "^varying\\_") + metrics_plt_ls[[sim_name]] <- plot_metrics( + results_ls[[sim_name]], vary_param_name, vary_param_name_vec, + ) + } else { + metrics_plt_ls[[sim_name]] <- NULL + } +} + +# plot restricted evaluation metrics +rmetrics_plt_ls <- list() +for (sim_name in names(results_ls)) { + if (length(metrics) > 0) { + vary_param_name <- stringr::str_remove(sim_name, "^varying\\_") + rmetrics_plt_ls[[sim_name]] <- plot_restricted_metrics( + results_ls[[sim_name]], vary_param_name, vary_param_name_vec, + ) + } else { + rmetrics_plt_ls[[sim_name]] <- NULL + } +} + +# plot tpr +tpr_plt_ls <- list() +for (sim_name in names(results_ls)) { + vary_param_name <- stringr::str_remove(sim_name, "^varying\\_") + tpr_plt_ls[[sim_name]] <- plot_tpr( + results_ls[[sim_name]], vary_param_name, vary_param_name_vec + ) +} + +# plot feature importances +fi_box_plt_ls <- list() +fi_bar_plt_ls <- list() +for (sim_name in names(results_ls)) { + vary_param_name <- stringr::str_remove(sim_name, "^varying\\_") + fi_box_plt_ls[[sim_name]] <- plot_feature_importance( + results_ls[[sim_name]], vary_param_name, vary_param_name_vec, keep_vars, plot_type = "boxplot" + ) + fi_bar_plt_ls[[sim_name]] <- plot_feature_importance( + results_ls[[sim_name]], vary_param_name, vary_param_name_vec, keep_vars, plot_type = "bar" + ) +} + +# plot heatmap +heatmap_plt_ls <- list() +for (sim_name in names(results_ls)) { + vary_param_name <- stringr::str_remove(sim_name, "^varying\\_") + heatmap_plt_ls[[sim_name]] <- plot_ranking_heatmap( + results_ls[[sim_name]], vary_param_name, vary_param_name_vec, keep_vars + ) +} +``` + +```{r results = "asis"} +# display plots nicely in knitted html document +view_results(results_ls, metrics_plt_ls, rmetrics_plt_ls, tpr_plt_ls, + fi_bar_plt_ls, fi_box_plt_ls, heatmap_plt_ls, + vary_param_name_vec, abridged) +``` + diff --git a/feature_importance/savio/run_simulations_mdi_plus.sh b/feature_importance/savio/run_simulations_mdi_plus.sh new file mode 100644 index 0000000..8bb29e6 --- /dev/null +++ b/feature_importance/savio/run_simulations_mdi_plus.sh @@ -0,0 +1,230 @@ +sims=( + # Regression: ccle_rnaseq + "mdi_plus.regression_sims.ccle_rnaseq_hier_poly_3m_2r_dgp" + "mdi_plus.regression_sims.ccle_rnaseq_linear_lss_3m_2r_dgp" + "mdi_plus.regression_sims.ccle_rnaseq_linear_dgp" + "mdi_plus.regression_sims.ccle_rnaseq_lss_3m_2r_dgp" + # Regression: enhancer + "mdi_plus.regression_sims.enhancer_hier_poly_3m_2r_dgp" + "mdi_plus.regression_sims.enhancer_linear_lss_3m_2r_dgp" + "mdi_plus.regression_sims.enhancer_linear_dgp" + "mdi_plus.regression_sims.enhancer_lss_3m_2r_dgp" + # Regression: juvenile + "mdi_plus.regression_sims.juvenile_hier_poly_3m_2r_dgp" + "mdi_plus.regression_sims.juvenile_linear_lss_3m_2r_dgp" + "mdi_plus.regression_sims.juvenile_linear_dgp" + "mdi_plus.regression_sims.juvenile_lss_3m_2r_dgp" + # Regression: splicing + "mdi_plus.regression_sims.splicing_hier_poly_3m_2r_dgp" + "mdi_plus.regression_sims.splicing_linear_lss_3m_2r_dgp" + "mdi_plus.regression_sims.splicing_linear_dgp" + "mdi_plus.regression_sims.splicing_lss_3m_2r_dgp" + + # Classification: ccle_rnaseq + 
"mdi_plus.classification_sims.ccle_rnaseq_lss_3m_2r_logistic_dgp" + "mdi_plus.classification_sims.ccle_rnaseq_logistic_dgp" + "mdi_plus.classification_sims.ccle_rnaseq_linear_lss_3m_2r_logistic_dgp" + "mdi_plus.classification_sims.ccle_rnaseq_hier_poly_3m_2r_logistic_dgp" + # Classification: enhancer + "mdi_plus.classification_sims.enhancer_lss_3m_2r_logistic_dgp" + "mdi_plus.classification_sims.enhancer_logistic_dgp" + "mdi_plus.classification_sims.enhancer_linear_lss_3m_2r_logistic_dgp" + "mdi_plus.classification_sims.enhancer_hier_poly_3m_2r_logistic_dgp" + # Classification: juvenile + "mdi_plus.classification_sims.juvenile_lss_3m_2r_logistic_dgp" + "mdi_plus.classification_sims.juvenile_logistic_dgp" + "mdi_plus.classification_sims.juvenile_linear_lss_3m_2r_logistic_dgp" + "mdi_plus.classification_sims.juvenile_hier_poly_3m_2r_logistic_dgp" + # Classification: splicing + "mdi_plus.classification_sims.splicing_lss_3m_2r_logistic_dgp" + "mdi_plus.classification_sims.splicing_logistic_dgp" + "mdi_plus.classification_sims.splicing_linear_lss_3m_2r_logistic_dgp" + "mdi_plus.classification_sims.splicing_hier_poly_3m_2r_logistic_dgp" + + # Robust: enhancer + "mdi_plus.robust_sims.enhancer_linear_10MS_robust_dgp" + "mdi_plus.robust_sims.enhancer_linear_25MS_robust_dgp" + "mdi_plus.robust_sims.enhancer_lss_3m_2r_10MS_robust_dgp" + "mdi_plus.robust_sims.enhancer_lss_3m_2r_25MS_robust_dgp" + "mdi_plus.robust_sims.enhancer_linear_lss_3m_2r_10MS_robust_dgp" + "mdi_plus.robust_sims.enhancer_linear_lss_3m_2r_25MS_robust_dgp" + "mdi_plus.robust_sims.enhancer_hier_poly_3m_2r_10MS_robust_dgp" + "mdi_plus.robust_sims.enhancer_hier_poly_3m_2r_25MS_robust_dgp" + # Robust: ccle_rnaseq + "mdi_plus.robust_sims.ccle_rnaseq_linear_10MS_robust_dgp" + "mdi_plus.robust_sims.ccle_rnaseq_linear_25MS_robust_dgp" + "mdi_plus.robust_sims.ccle_rnaseq_lss_3m_2r_10MS_robust_dgp" + "mdi_plus.robust_sims.ccle_rnaseq_lss_3m_2r_25MS_robust_dgp" + "mdi_plus.robust_sims.ccle_rnaseq_linear_lss_3m_2r_10MS_robust_dgp" + "mdi_plus.robust_sims.ccle_rnaseq_linear_lss_3m_2r_25MS_robust_dgp" + "mdi_plus.robust_sims.ccle_rnaseq_hier_poly_3m_2r_10MS_robust_dgp" + "mdi_plus.robust_sims.ccle_rnaseq_hier_poly_3m_2r_25MS_robust_dgp" + + # MDI bias simulations + "mdi_plus.mdi_bias_sims.correlation_sims.normal_block_cor_partial_linear_lss_dgp" + "mdi_plus.mdi_bias_sims.entropy_sims.linear_dgp" + "mdi_plus.mdi_bias_sims.entropy_sims.logistic_dgp" + + # Regression (varying number of features): ccle_rnaseq + "mdi_plus.other_regression_sims.varying_p.ccle_rnaseq_hier_poly_3m_2r_dgp" + "mdi_plus.other_regression_sims.varying_p.ccle_rnaseq_linear_lss_3m_2r_dgp" + "mdi_plus.other_regression_sims.varying_p.ccle_rnaseq_linear_dgp" + "mdi_plus.other_regression_sims.varying_p.ccle_rnaseq_lss_3m_2r_dgp" + # Regression (varying sparsity level): juvenile + "mdi_plus.other_regression_sims.varying_sparsity.juvenile_hier_poly_3m_2r_dgp" + "mdi_plus.other_regression_sims.varying_sparsity.juvenile_linear_lss_3m_2r_dgp" + "mdi_plus.other_regression_sims.varying_sparsity.juvenile_linear_dgp" + "mdi_plus.other_regression_sims.varying_sparsity.juvenile_lss_3m_2r_dgp" + # Regression (varying sparsity level): splicing + "mdi_plus.other_regression_sims.varying_sparsity.splicing_hier_poly_3m_2r_dgp" + "mdi_plus.other_regression_sims.varying_sparsity.splicing_linear_lss_3m_2r_dgp" + "mdi_plus.other_regression_sims.varying_sparsity.splicing_linear_dgp" + "mdi_plus.other_regression_sims.varying_sparsity.splicing_lss_3m_2r_dgp" + + # MDI+ Modeling Choices: Enhancer 
(min_samples_per_leaf = 5) + "mdi_plus.other_regression_sims.modeling_choices_min_samples5.enhancer_hier_poly_3m_2r_dgp" + "mdi_plus.other_regression_sims.modeling_choices_min_samples5.enhancer_linear_dgp" + # MDI+ Modeling Choices: CCLE (min_samples_per_leaf = 5) + "mdi_plus.other_regression_sims.modeling_choices_min_samples5.ccle_rnaseq_hier_poly_3m_2r_dgp" + "mdi_plus.other_regression_sims.modeling_choices_min_samples5.ccle_rnaseq_linear_dgp" + # MDI+ Modeling Choices: Enhancer (min_samples_per_leaf = 1) + "mdi_plus.other_regression_sims.modeling_choices_min_samples1.enhancer_hier_poly_3m_2r_dgp" + "mdi_plus.other_regression_sims.modeling_choices_min_samples1.enhancer_linear_dgp" + # MDI+ Modeling Choices: CCLE (min_samples_per_leaf = 1) + "mdi_plus.other_regression_sims.modeling_choices_min_samples1.ccle_rnaseq_hier_poly_3m_2r_dgp" + "mdi_plus.other_regression_sims.modeling_choices_min_samples1.ccle_rnaseq_linear_dgp" + + # MDI+ GLM and Metric Choices: Regression + "mdi_plus.glm_metric_choices_sims.regression_sims.ccle_rnaseq_hier_poly_3m_2r_dgp" + "mdi_plus.glm_metric_choices_sims.regression_sims.ccle_rnaseq_linear_dgp" + "mdi_plus.glm_metric_choices_sims.regression_sims.enhancer_hier_poly_3m_2r_dgp" + "mdi_plus.glm_metric_choices_sims.regression_sims.enhancer_linear_dgp" + # MDI+ GLM and Metric Choices: Classification + "mdi_plus.glm_metric_choices_sims.classification_sims.juvenile_logistic_dgp" + "mdi_plus.glm_metric_choices_sims.classification_sims.juvenile_hier_poly_3m_2r_logistic_dgp" + "mdi_plus.glm_metric_choices_sims.classification_sims.splicing_logistic_dgp" + "mdi_plus.glm_metric_choices_sims.classification_sims.splicing_hier_poly_3m_2r_logistic_dgp" +) + +for sim in "${sims[@]}" +do + sbatch --job-name=${sim} submit_simulation_job.sh ${sim} +done + + +## Misspecified model simulations +misspecifiedsims=( + # Misspecified Regression: ccle_rnaseq + "mdi_plus.regression_sims.ccle_rnaseq_hier_poly_3m_2r_dgp" + "mdi_plus.regression_sims.ccle_rnaseq_linear_lss_3m_2r_dgp" + "mdi_plus.regression_sims.ccle_rnaseq_linear_dgp" + "mdi_plus.regression_sims.ccle_rnaseq_lss_3m_2r_dgp" + # Misspecified Regression: enhancer + "mdi_plus.regression_sims.enhancer_hier_poly_3m_2r_dgp" + "mdi_plus.regression_sims.enhancer_linear_lss_3m_2r_dgp" + "mdi_plus.regression_sims.enhancer_linear_dgp" + "mdi_plus.regression_sims.enhancer_lss_3m_2r_dgp" + # Misspecified Regression: juvenile + "mdi_plus.regression_sims.juvenile_hier_poly_3m_2r_dgp" + "mdi_plus.regression_sims.juvenile_linear_lss_3m_2r_dgp" + "mdi_plus.regression_sims.juvenile_linear_dgp" + "mdi_plus.regression_sims.juvenile_lss_3m_2r_dgp" + # Misspecified Regression: splicing + "mdi_plus.regression_sims.splicing_hier_poly_3m_2r_dgp" + "mdi_plus.regression_sims.splicing_linear_lss_3m_2r_dgp" + "mdi_plus.regression_sims.splicing_linear_dgp" + "mdi_plus.regression_sims.splicing_lss_3m_2r_dgp" +) + +for sim in "${misspecifiedsims[@]}" +do + sbatch --job-name=${sim}_omitted_vars submit_simulation_job_omitted_vars.sh ${sim} +done + + +## Real-data Case Study +# CCLE RNASeq +drugs=( + "17-AAG" + "AEW541" + "AZD0530" + "AZD6244" + "Erlotinib" + "Irinotecan" + "L-685458" + "LBW242" + "Lapatinib" + "Nilotinib" + "Nutlin-3" + "PD-0325901" + "PD-0332991" + "PF2341066" + "PHA-665752" + "PLX4720" + "Paclitaxel" + "Panobinostat" + "RAF265" + "Sorafenib" + "TAE684" + "TKI258" + "Topotecan" + "ZD-6474" +) + +sim="mdi_plus.real_data_case_study.ccle_rnaseq_regression-" +for drug in "${drugs[@]}" +do + sbatch --job-name=${sim}_${drug} 
submit_simulation_job_real_data_multitask.sh ${sim} "regression" ${drug} +done + +sim="mdi_plus.real_data_case_study_no_data_split.ccle_rnaseq_regression-" +for drug in "${drugs[@]}" +do + sbatch --job-name=${sim}_${drug} submit_simulation_job_real_data_multitask.sh ${sim} "regression" ${drug} +done + +# TCGA BRCA +sim="mdi_plus.real_data_case_study.tcga_brca_classification-" +sbatch --job-name=${sim} submit_simulation_job_real_data.sh ${sim} "multiclass_classification" + +sim="mdi_plus.real_data_case_study_no_data_split.tcga_brca_classification-" +sbatch --job-name=${sim} submit_simulation_job_real_data.sh ${sim} "multiclass_classification" + + +## Prediction Simulations + +# Real Data: Binary classification +sims=( + "mdi_plus.prediction_sims.enhancer_classification-" + "mdi_plus.prediction_sims.juvenile_classification-" + "mdi_plus.prediction_sims.splicing_classification-" +) + +for sim in "${sims[@]}" +do + sbatch --job-name=${sim} submit_simulation_job_real_data_prediction.sh ${sim} "binary_classification" +done + +# Real Data: Multi-class classification +sim="mdi_plus.prediction_sims.tcga_brca_classification-" +sbatch --job-name=${sim} submit_simulation_job_real_data_prediction.sh ${sim} "multiclass_classification" + +# Real Data: Regression +sim="mdi_plus.prediction_sims.ccle_rnaseq_regression-" +for drug in "${drugs[@]}" +do + sbatch --job-name=${sim}_${drug} submit_simulation_job_real_data_prediction_multitask.sh ${sim} ${drug} "regression" +done + +# MDI+ GLM and Metric Choices: Regression +sims=( + "mdi_plus.glm_metric_choices_sims.regression_prediction_sims.ccle_rnaseq_hier_poly_3m_2r_dgp" + "mdi_plus.glm_metric_choices_sims.regression_prediction_sims.ccle_rnaseq_linear_dgp" + "mdi_plus.glm_metric_choices_sims.regression_prediction_sims.enhancer_hier_poly_3m_2r_dgp" + "mdi_plus.glm_metric_choices_sims.regression_prediction_sims.enhancer_linear_dgp" +) + +for sim in "${sims[@]}" +do + sbatch --job-name=${sim} submit_simulation_job_prediction.sh ${sim} "regression" +done \ No newline at end of file diff --git a/feature_importance/savio/submit_simulation_job.sh b/feature_importance/savio/submit_simulation_job.sh new file mode 100644 index 0000000..5421aed --- /dev/null +++ b/feature_importance/savio/submit_simulation_job.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH --account=co_stat +#SBATCH --partition=savio +#SBATCH --time=12:00:00 +# +#SBATCH --nodes=1 + +module load python/3.7 +module load r + +source activate r2f + +python ../01_run_importance_simulations.py --nreps 50 --config ${1} --split_seed 12345 --parallel --nosave_cols "prediction_model" diff --git a/feature_importance/savio/submit_simulation_job_omitted_vars.sh b/feature_importance/savio/submit_simulation_job_omitted_vars.sh new file mode 100644 index 0000000..737fd8d --- /dev/null +++ b/feature_importance/savio/submit_simulation_job_omitted_vars.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH --account=co_stat +#SBATCH --partition=savio +#SBATCH --time=12:00:00 +# +#SBATCH --nodes=1 + +module load python/3.7 +module load r + +source activate r2f + +python ../01_run_importance_simulations.py --nreps 50 --config ${1} --split_seed 12345 --omit_vars 0,1 --parallel --nosave_cols "prediction_model" diff --git a/feature_importance/savio/submit_simulation_job_prediction.sh b/feature_importance/savio/submit_simulation_job_prediction.sh new file mode 100644 index 0000000..f3c60ad --- /dev/null +++ b/feature_importance/savio/submit_simulation_job_prediction.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH --account=co_stat +#SBATCH 
--partition=savio +#SBATCH --time=12:00:00 +# +#SBATCH --nodes=1 + +module load python/3.7 +module load r + +source activate r2f + +python ../03_run_prediction_simulations.py --nreps 50 --config ${1} --split_seed 12345 --parallel --mode ${2} --nosave_cols "prediction_model" diff --git a/feature_importance/savio/submit_simulation_job_real_data.sh b/feature_importance/savio/submit_simulation_job_real_data.sh new file mode 100644 index 0000000..1fdd2b8 --- /dev/null +++ b/feature_importance/savio/submit_simulation_job_real_data.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH --account=co_stat +#SBATCH --partition=savio +#SBATCH --time=48:00:00 +# +#SBATCH --nodes=1 + +module load python/3.7 +module load r + +source activate r2f + +python ../02_run_importance_real_data.py --nreps 32 --config ${1} --split_seed 12345 --parallel --nosave_cols "prediction_model" diff --git a/feature_importance/savio/submit_simulation_job_real_data_multitask.sh b/feature_importance/savio/submit_simulation_job_real_data_multitask.sh new file mode 100644 index 0000000..1a4ec9f --- /dev/null +++ b/feature_importance/savio/submit_simulation_job_real_data_multitask.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH --account=co_stat +#SBATCH --partition=savio +#SBATCH --time=48:00:00 +# +#SBATCH --nodes=1 + +module load python/3.7 +module load r + +source activate r2f + +python ../02_run_importance_real_data.py --nreps 32 --config ${1} --response_idx ${2} --split_seed 12345 --parallel --nosave_cols "prediction_model" diff --git a/feature_importance/savio/submit_simulation_job_real_data_prediction.sh b/feature_importance/savio/submit_simulation_job_real_data_prediction.sh new file mode 100644 index 0000000..b610ec5 --- /dev/null +++ b/feature_importance/savio/submit_simulation_job_real_data_prediction.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH --account=co_stat +#SBATCH --partition=savio +#SBATCH --time=24:00:00 +# +#SBATCH --nodes=1 + +module load python/3.7 +module load r + +source activate r2f + +python ../04_run_prediction_real_data.py --nreps 32 --config ${1} --split_seed 12345 --parallel --mode ${2} --subsample_n 10000 --nosave_cols "prediction_model" diff --git a/feature_importance/savio/submit_simulation_job_real_data_prediction_multitask.sh b/feature_importance/savio/submit_simulation_job_real_data_prediction_multitask.sh new file mode 100644 index 0000000..5110468 --- /dev/null +++ b/feature_importance/savio/submit_simulation_job_real_data_prediction_multitask.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH --account=co_stat +#SBATCH --partition=savio +#SBATCH --time=24:00:00 +# +#SBATCH --nodes=1 + +module load python/3.7 +module load r + +source activate r2f + +python ../04_run_prediction_real_data.py --nreps 32 --config ${1} --split_seed 12345 --parallel --response_idx ${2} --mode ${3} --nosave_cols "prediction_model" diff --git a/feature_importance/scripts/__init__.py b/feature_importance/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feature_importance/scripts/competing_methods.py b/feature_importance/scripts/competing_methods.py new file mode 100644 index 0000000..7a96079 --- /dev/null +++ b/feature_importance/scripts/competing_methods.py @@ -0,0 +1,236 @@ +import os +import sys +import pandas as pd +import numpy as np +import sklearn.base +from sklearn.base import RegressorMixin, ClassifierMixin +from functools import reduce + +import shap +from imodels.importance.rf_plus import RandomForestPlusRegressor, RandomForestPlusClassifier +from feature_importance.scripts.mdi_oob import MDI_OOB +from 
feature_importance.scripts.mda import MDA + + +def tree_mdi_plus_ensemble(X, y, fit, scoring_fns="auto", **kwargs): + """ + Wrapper around MDI+ object to get feature importance scores + + :param X: ndarray of shape (n_samples, n_features) + The covariate matrix. If a pd.DataFrame object is supplied, then + the column names are used in the output + :param y: ndarray of shape (n_samples, n_targets) + The observed responses. + :param rf_model: scikit-learn random forest object or None + The RF model to be used for interpretation. If None, then a new + RandomForestRegressor or RandomForestClassifier is instantiated. + :param kwargs: additional arguments to pass to + RandomForestPlusRegressor or RandomForestPlusClassifier class. + :return: dataframe - [Var, Importance] + Var: variable name + Importance: MDI+ score + """ + + if isinstance(fit, RegressorMixin): + RFPlus = RandomForestPlusRegressor + elif isinstance(fit, ClassifierMixin): + RFPlus = RandomForestPlusClassifier + else: + raise ValueError("Unknown task.") + + mdi_plus_scores_dict = {} + for rf_plus_name, rf_plus_args in kwargs.items(): + rf_plus_model = RFPlus(rf_model=fit, **rf_plus_args) + rf_plus_model.fit(X, y) + try: + mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, scoring_fns=scoring_fns) + except ValueError as e: + if str(e) == 'Transformer representation was empty for all trees.': + mdi_plus_scores = pd.DataFrame(data=np.zeros(X.shape[1]), columns=['importance']) + if isinstance(X, pd.DataFrame): + mdi_plus_scores.index = X.columns + mdi_plus_scores.index.name = 'var' + mdi_plus_scores.reset_index(inplace=True) + else: + raise + for col in mdi_plus_scores.columns: + if col != "var": + mdi_plus_scores = mdi_plus_scores.rename(columns={col: col + "_" + rf_plus_name}) + mdi_plus_scores_dict[rf_plus_name] = mdi_plus_scores + + mdi_plus_scores_df = pd.concat([df.set_index('var') for df in mdi_plus_scores_dict.values()], axis=1) + mdi_plus_ranks_df = mdi_plus_scores_df.rank(ascending=False).median(axis=1) + mdi_plus_ranks_df = pd.DataFrame(mdi_plus_ranks_df, columns=["importance"]).reset_index() + + return mdi_plus_ranks_df + + +def tree_mdi_plus(X, y, fit, scoring_fns="auto", return_stability_scores=False, **kwargs): + """ + Wrapper around MDI+ object to get feature importance scores + + :param X: ndarray of shape (n_samples, n_features) + The covariate matrix. If a pd.DataFrame object is supplied, then + the column names are used in the output + :param y: ndarray of shape (n_samples, n_targets) + The observed responses. + :param rf_model: scikit-learn random forest object or None + The RF model to be used for interpretation. If None, then a new + RandomForestRegressor or RandomForestClassifier is instantiated. + :param kwargs: additional arguments to pass to + RandomForestPlusRegressor or RandomForestPlusClassifier class. 
+ :return: dataframe - [Var, Importance] + Var: variable name + Importance: MDI+ score + """ + + if isinstance(fit, RegressorMixin): + RFPlus = RandomForestPlusRegressor + elif isinstance(fit, ClassifierMixin): + RFPlus = RandomForestPlusClassifier + else: + raise ValueError("Unknown task.") + rf_plus_model = RFPlus(rf_model=fit, **kwargs) + rf_plus_model.fit(X, y) + try: + mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, scoring_fns=scoring_fns) + if return_stability_scores: + stability_scores = rf_plus_model.get_mdi_plus_stability_scores(B=25) + except ValueError as e: + if str(e) == 'Transformer representation was empty for all trees.': + mdi_plus_scores = pd.DataFrame(data=np.zeros(X.shape[1]), columns=['importance']) + if isinstance(X, pd.DataFrame): + mdi_plus_scores.index = X.columns + mdi_plus_scores.index.name = 'var' + mdi_plus_scores.reset_index(inplace=True) + stability_scores = None + else: + raise + mdi_plus_scores["prediction_score"] = rf_plus_model.prediction_score_ + if return_stability_scores: + mdi_plus_scores = pd.concat([mdi_plus_scores, stability_scores], axis=1) + + return mdi_plus_scores + + +def tree_mdi(X, y, fit, include_num_splits=False): + """ + Extract MDI values for a given tree + OR + Average MDI values for a given random forest + :param X: design matrix + :param y: response + :param fit: fitted model of interest + :return: dataframe - [Var, Importance] + Var: variable name + Importance: MDI or avg MDI + """ + av_splits = get_num_splits(X, y, fit) + results = fit.feature_importances_ + results = pd.DataFrame(data=results, columns=['importance']) + + # Use column names from dataframe if possible + if isinstance(X, pd.DataFrame): + results.index = X.columns + results.index.name = 'var' + results.reset_index(inplace=True) + + if include_num_splits: + results['av_splits'] = av_splits + + return results + + +def tree_mdi_OOB(X, y, fit, type='oob', + normalized=False, balanced=False, demean=False, normal_fX=False): + """ + Compute MDI-oob feature importance for a given random forest + :param X: design matrix + :param y: response + :param fit: fitted model of interest + :return: dataframe - [Var, Importance] + Var: variable name + Importance: MDI-oob + """ + reshaped_y = y.reshape((len(y), 1)) + results = MDI_OOB(fit, X, reshaped_y, type=type, normalized=normalized, balanced=balanced, + demean=demean, normal_fX=normal_fX)[0] + results = pd.DataFrame(data=results, columns=['importance']) + if isinstance(X, pd.DataFrame): + results.index = X.columns + results.index.name = 'var' + results.reset_index(inplace=True) + + return results + + +def tree_shap(X, y, fit): + """ + Compute average treeshap value across observations + :param X: design matrix + :param y: response + :param fit: fitted model of interest (tree-based) + :return: dataframe - [Var, Importance] + Var: variable name + Importance: average absolute shap value + """ + explainer = shap.TreeExplainer(fit) + shap_values = explainer.shap_values(X, check_additivity=False) + if sklearn.base.is_classifier(fit): + def add_abs(a, b): + return abs(a) + abs(b) + results = reduce(add_abs, shap_values) + else: + results = abs(shap_values) + results = results.mean(axis=0) + results = pd.DataFrame(data=results, columns=['importance']) + # Use column names from dataframe if possible + if isinstance(X, pd.DataFrame): + results.index = X.columns + results.index.name = 'var' + results.reset_index(inplace=True) + + return results + + +def tree_mda(X, y, fit, type="oob", n_repeats=10, metric="auto"): + """ + Compute 
MDA importance for a given random forest + :param X: design matrix + :param y: response + :param fit: fitted model of interest + :param type: "oob" or "train" + :param n_repeats: number of permutations + :param metric: metric used to compute MDA/permutation importance + :return: dataframe - [Var, Importance] + Var: variable name + Importance: MDA + """ + if metric == "auto": + if isinstance(y[0], str): + metric = "accuracy" + else: + metric = "mse" + + results, _ = MDA(fit, X, y[:, np.newaxis], type=type, n_trials=n_repeats, metric=metric) + results = pd.DataFrame(data=results, columns=['importance']) + + if isinstance(X, pd.DataFrame): + results.index = X.columns + results.index.name = 'var' + results.reset_index(inplace=True) + + return results + + +def get_num_splits(X, y, fit): + """ + Gets the average number of splits per feature across trees in a fitted RF + """ + num_splits_feature = np.zeros(X.shape[1]) + for tree in fit.estimators_: + tree_features = tree.tree_.feature + for i in range(X.shape[1]): + num_splits_feature[i] += np.count_nonzero(tree_features == i) + # average the per-feature split counts over all trees in the forest + num_splits_feature /= len(fit.estimators_) + return num_splits_feature diff --git a/feature_importance/scripts/data_cleaning_mdi_plus.R b/feature_importance/scripts/data_cleaning_mdi_plus.R new file mode 100644 index 0000000..a932cbf --- /dev/null +++ b/feature_importance/scripts/data_cleaning_mdi_plus.R @@ -0,0 +1,65 @@ +#### Setup #### +library(magrittr) +if (!require("vdocs")) devtools::install_github("Yu-Group/vdocs") + +data_dir <- file.path("..", "data") + +#### Remove duplicate TFs in Enhancer #### +load(file.path(data_dir, "enhancer.Rdata")) + +keep_vars <- varnames.all %>% + dplyr::group_by(Predictor_collapsed) %>% + dplyr::mutate(id = 1:dplyr::n()) %>% + dplyr::filter(id == 1) +write.csv( + X[, keep_vars$Predictor], + file.path(data_dir, "X_enhancer.csv"), + row.names = FALSE +) + +#### Clean covariate matrices #### +X_paths <- c("X_juvenile.csv", + "X_splicing.csv", + "X_ccle_rnaseq.csv", + "X_enhancer.csv") +log_transform <- c("X_splicing.csv", + "X_ccle_rnaseq.csv", + "X_enhancer.csv") + +for (X_path in X_paths) { + X_orig <- data.table::fread(file.path(data_dir, X_path)) %>% + tibble::as_tibble() + + # dim(X_orig) + # sum(is.na(X_orig)) + + X <- X_orig %>% + vdocs::remove_constant_cols(verbose = 2) %>% + vdocs::remove_duplicate_cols(verbose = 2) + + # hist(as.matrix(X)) + if (X_path %in% log_transform) { + X <- log(X + 1) + # hist(as.matrix(X)) + } + + # dim(X) + + write.csv( + X, + file.path(data_dir, sprintf("%s_cleaned.csv", fs::path_ext_remove(X_path))), + row.names = FALSE + ) +} + +#### Filter number of features for real data case study #### +X_orig <- data.table::fread(file.path(data_dir, "X_ccle_rnaseq_cleaned.csv")) + +X <- X_orig %>% + vdocs::filter_cols_by_var(max_p = 5000) + +write.csv( + X, + file.path(data_dir, "X_ccle_rnaseq_cleaned_filtered5000.csv"), + row.names = FALSE +) diff --git a/feature_importance/scripts/mda.py b/feature_importance/scripts/mda.py new file mode 100644 index 0000000..e0c59fa --- /dev/null +++ b/feature_importance/scripts/mda.py @@ -0,0 +1,69 @@ +from sklearn.ensemble._forest import _generate_unsampled_indices, _generate_sample_indices +from sklearn.metrics import accuracy_score, mean_squared_error +import numpy as np +import copy +from sklearn.preprocessing import LabelEncoder + + +def MDA(rf, X, y, type = 'oob', n_trials = 10, metric = 'accuracy'): + if len(y.shape) != 2: + raise ValueError('y must be 2d array (n_samples, 1) if numerical or (n_samples, n_categories).') + + y_mda = copy.deepcopy(y) + if
rf._estimator_type == "classifier" and y.dtype == "object": + y_mda = LabelEncoder().fit(y_mda.ravel()).transform(y_mda.ravel()).reshape(y_mda.shape[0], 1) + + n_samples, n_features = X.shape + fi_mean = np.zeros((n_features,)) + fi_std = np.zeros((n_features,)) + best_score = rf_accuracy(rf, X, y_mda, type = type, metric = metric) + for f in range(n_features): + permute_score = 0 + permute_std = 0 + X_permute = X.copy() + for i in range(n_trials): + X_permute[:, f] = np.random.permutation(X_permute[:, f]) + to_add = rf_accuracy(rf, X_permute, y_mda, type = type, metric = metric) + permute_score += to_add + permute_std += to_add ** 2 + permute_score /= n_trials + permute_std /= n_trials + permute_std = (permute_std - permute_score ** 2) ** .5 / n_trials ** .5 + fi_mean[f] = best_score - permute_score + fi_std[f] = permute_std + return fi_mean, fi_std + + +def neg_mse(y, y_hat): + return - mean_squared_error(y, y_hat) + + +def rf_accuracy(rf, X, y, type = 'oob', metric = 'accuracy'): + if metric == 'accuracy': + score = accuracy_score + elif metric == 'mse': + score = neg_mse + else: + raise ValueError('metric type not understood') + + n_samples, n_features = X.shape + tmp = 0 + count = 0 + if type == 'test': + return score(y, rf.predict(X)) + elif type == 'train' and not rf.bootstrap: + return score(y, rf.predict(X)) + + for tree in rf.estimators_: + if type == 'oob': + if rf.bootstrap: + indices = _generate_unsampled_indices(tree.random_state, n_samples, n_samples) + else: + raise ValueError('Without bootstrap, it is not possible to calculate oob.') + elif type == 'train': + indices = _generate_sample_indices(tree.random_state, n_samples, n_samples) + else: + raise ValueError('type is not recognized. (%s)'%(type)) + tmp += score(y[indices,:], tree.predict(X[indices, :])) * len(indices) + count += len(indices) + return tmp / count \ No newline at end of file diff --git a/feature_importance/scripts/mdi_oob.py b/feature_importance/scripts/mdi_oob.py new file mode 100644 index 0000000..7c95228 --- /dev/null +++ b/feature_importance/scripts/mdi_oob.py @@ -0,0 +1,282 @@ +import numpy as np +import warnings +import sklearn + +from sklearn.ensemble._forest import ForestClassifier, ForestRegressor +from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier, _tree +from distutils.version import LooseVersion +from sklearn.ensemble._forest import _generate_unsampled_indices, _generate_sample_indices +from sklearn.metrics import accuracy_score, mean_squared_error +from sklearn.preprocessing import scale + + +def MDI_OOB(rf, X, y, type='oob', normalized=False, balanced=False, demean=False, normal_fX=False): + n_samples, n_features = X.shape + if len(y.shape) != 2: + raise ValueError('y must be 2d array (n_samples, 1) if numerical or (n_samples, n_categories).') + out = np.zeros((n_features,)) + SE = np.zeros((n_features,)) + if demean: + # demean y + y = y - np.mean(y, axis=0) + + for tree in rf.estimators_: + if type == 'oob': + if rf.bootstrap: + indices = _generate_unsampled_indices(tree.random_state, n_samples, n_samples) + else: + raise ValueError('Without bootstrap, it is not possible to calculate oob.') + elif type == 'test': + indices = np.arange(n_samples) + elif type == 'classic': + if rf.bootstrap: + indices = _generate_sample_indices(tree.random_state, n_samples, n_samples) + else: + indices = np.arange(n_samples) + else: + raise ValueError('type is not recognized. 
(%s)'%(type)) + _, _, contributions = _predict_tree(tree, X[indices, :]) + if balanced and (type == 'oob' or type == 'test'): + base_indices = _generate_sample_indices(tree.random_state, n_samples, n_samples) + ids = tree.apply(X[indices, :]) + base_ids = tree.apply(X[base_indices, :]) + tmp1, tmp2 = np.unique(ids, return_counts=True) + weight1 = {key: 1. / value for key, value in zip(tmp1, tmp2)} + tmp1, tmp2 = np.unique(base_ids, return_counts=True) + weight2 = {key: value for key, value in zip(tmp1, tmp2)} + final_weights = np.array([[weight1[id] * weight2[id]] for id in ids]) + final_weights /= np.mean(final_weights) + else: + final_weights = 1 + if len(contributions.shape) == 2: + contributions = contributions[:, :, np.newaxis] + # print(contributions.shape, y[indices,:].shape) + if normal_fX: + for k in range(contributions.shape[-1]): + contributions[:, :, k] = scale(contributions[:, :, k]) + if contributions.shape[2] == 2: + contributions = contributions[:, :, 1:] + elif contributions.shape[2] > 2: + raise ValueError('Multi-class y is not currently supported.') + tmp = np.tensordot(np.array(y[indices, :]) * final_weights, contributions, axes=([0, 1], [0, 2])) + if normalized: + # if sum(tmp) != 0: + # out += tmp / sum(tmp) + out += tmp / sum(tmp) + else: + out += tmp / len(indices) + if normalized: + # if sum(tmp) != 0: + # SE += (tmp / sum(tmp)) ** 2 + SE += (tmp / sum(tmp)) ** 2 + else: + SE += (tmp / len(indices)) ** 2 + out /= rf.n_estimators + SE /= rf.n_estimators + SE = ((SE - out ** 2) / rf.n_estimators) ** .5 + return out, SE + + +def _get_tree_paths(tree, node_id, depth=0): + """ + Returns all paths through the tree as list of node_ids + """ + if node_id == _tree.TREE_LEAF: + raise ValueError("Invalid node_id %s" % _tree.TREE_LEAF) + + left_child = tree.children_left[node_id] + right_child = tree.children_right[node_id] + + if left_child != _tree.TREE_LEAF: + left_paths = _get_tree_paths(tree, left_child, depth=depth + 1) + right_paths = _get_tree_paths(tree, right_child, depth=depth + 1) + + for path in left_paths: + path.append(node_id) + for path in right_paths: + path.append(node_id) + paths = left_paths + right_paths + else: + paths = [[node_id]] + return paths + + +def _predict_tree(model, X, joint_contribution=False): + """ + For a given DecisionTreeRegressor, DecisionTreeClassifier, + ExtraTreeRegressor, or ExtraTreeClassifier, + returns a triple of [prediction, bias and feature_contributions], such + that prediction ≈ bias + feature_contributions. + """ + leaves = model.apply(X) + paths = _get_tree_paths(model.tree_, 0) + + for path in paths: + path.reverse() + + leaf_to_path = {} + # map leaves to paths + for path in paths: + leaf_to_path[path[-1]] = path + + # remove the single-dimensional inner arrays + values = model.tree_.value.squeeze(axis=1) + # reshape if squeezed into a single float + if len(values.shape) == 0: + values = np.array([values]) + if isinstance(model, DecisionTreeRegressor): + biases = np.full(X.shape[0], values[paths[0][0]]) + line_shape = X.shape[1] + elif isinstance(model, DecisionTreeClassifier): + # scikit stores category counts, we turn them into probabilities + normalizer = values.sum(axis=1)[:, np.newaxis] + normalizer[normalizer == 0.0] = 1.0 + values /= normalizer + + biases = np.tile(values[paths[0][0]], (X.shape[0], 1)) + line_shape = (X.shape[1], model.n_classes_) + else: + warnings.warn('the instance is not recognized. 
Try to proceed with classifier but could fail.') + normalizer = values.sum(axis=1)[:, np.newaxis] + normalizer[normalizer == 0.0] = 1.0 + values /= normalizer + + biases = np.tile(values[paths[0][0]], (X.shape[0], 1)) + line_shape = (X.shape[1], model.n_classes_) + + direct_prediction = values[leaves] + + # make into python list, accessing values will be faster + values_list = list(values) + feature_index = list(model.tree_.feature) + + contributions = [] + if joint_contribution: + for row, leaf in enumerate(leaves): + path = leaf_to_path[leaf] + + path_features = set() + contributions.append({}) + for i in range(len(path) - 1): + path_features.add(feature_index[path[i]]) + contrib = values_list[path[i + 1]] - \ + values_list[path[i]] + # path_features.sort() + contributions[row][tuple(sorted(path_features))] = \ + contributions[row].get(tuple(sorted(path_features)), 0) + contrib + return direct_prediction, biases, contributions + + else: + unique_leaves = np.unique(leaves) + unique_contributions = {} + + for row, leaf in enumerate(unique_leaves): + for path in paths: + if leaf == path[-1]: + break + + contribs = np.zeros(line_shape) + for i in range(len(path) - 1): + contrib = values_list[path[i + 1]] - \ + values_list[path[i]] + contribs[feature_index[path[i]]] += contrib + unique_contributions[leaf] = contribs + + for row, leaf in enumerate(leaves): + contributions.append(unique_contributions[leaf]) + + return direct_prediction, biases, np.array(contributions) + + +def _predict_forest(model, X, joint_contribution=False): + """ + For a given RandomForestRegressor, RandomForestClassifier, + ExtraTreesRegressor, or ExtraTreesClassifier returns a triple of + [prediction, bias and feature_contributions], such that prediction ≈ bias + + feature_contributions. + """ + biases = [] + contributions = [] + predictions = [] + + if joint_contribution: + + for tree in model.estimators_: + pred, bias, contribution = _predict_tree(tree, X, joint_contribution=joint_contribution) + + biases.append(bias) + contributions.append(contribution) + predictions.append(pred) + + total_contributions = [] + + for i in range(len(X)): + contr = {} + for j, dct in enumerate(contributions): + for k in set(dct[i]).union(set(contr.keys())): + contr[k] = (contr.get(k, 0) * j + dct[i].get(k, 0)) / (j + 1) + + total_contributions.append(contr) + + for i, item in enumerate(contribution): + total_contributions[i] + sm = sum([v for v in contribution[i].values()]) + + return (np.mean(predictions, axis=0), np.mean(biases, axis=0), + total_contributions) + else: + for tree in model.estimators_: + pred, bias, contribution = _predict_tree(tree, X) + + biases.append(bias) + contributions.append(contribution) + predictions.append(pred) + + return (np.mean(predictions, axis=0), np.mean(biases, axis=0), + np.mean(contributions, axis=0)) + + +def predict(model, X, joint_contribution=False): + """ Returns a triple (prediction, bias, feature_contributions), such + that prediction ≈ bias + feature_contributions. + Parameters + ---------- + model : DecisionTreeRegressor, DecisionTreeClassifier, + ExtraTreeRegressor, ExtraTreeClassifier, + RandomForestRegressor, RandomForestClassifier, + ExtraTreesRegressor, ExtraTreesClassifier + Scikit-learn model on which the prediction should be decomposed. + X : array-like, shape = (n_samples, n_features) + Test samples. 
+ + joint_contribution : boolean + Specifies if contributions are given individually from each feature, + or jointly over them + Returns + ------- + decomposed prediction : triple of + * prediction, shape = (n_samples) for regression and (n_samples, n_classes) + for classification + * bias, shape = (n_samples) for regression and (n_samples, n_classes) for + classification + * contributions, If joint_contribution is False then returns an array of + shape = (n_samples, n_features) for regression or + shape = (n_samples, n_features, n_classes) for classification, denoting + contribution from each feature. + If joint_contribution is True, then shape is array of size n_samples, + where each array element is a dict from a tuple of feature indices + to a value denoting the contribution from that feature tuple. + """ + # Only a single output response variable is supported + if model.n_outputs_ > 1: + raise ValueError("Multilabel classification trees not supported") + + if (isinstance(model, DecisionTreeClassifier) or + isinstance(model, DecisionTreeRegressor)): + return _predict_tree(model, X, joint_contribution=joint_contribution) + elif (isinstance(model, ForestClassifier) or + isinstance(model, ForestRegressor)): + return _predict_forest(model, X, joint_contribution=joint_contribution) + else: + raise ValueError("Wrong model type. Base learner needs to be a " + "DecisionTreeClassifier or DecisionTreeRegressor.") diff --git a/feature_importance/scripts/simulations_util.py b/feature_importance/scripts/simulations_util.py new file mode 100644 index 0000000..fab8027 --- /dev/null +++ b/feature_importance/scripts/simulations_util.py @@ -0,0 +1,1019 @@ +import numpy as np +import pandas as pd +import random +from scipy.linalg import toeplitz +import warnings +import math + + +def sample_real_X(fpath=None, X=None, seed=None, normalize=True, + sample_row_n=None, sample_col_n=None, permute_col=True, + signal_features=None, n_signal_features=None, permute_nonsignal_col=None): + """ + :param fpath: path to X data + :param X: data matrix + :param seed: random seed + :param normalize: boolean; whether or not to normalize columns in data to mean 0 and variance 1 + :param sample_row_n: number of samples to subset; default keeps all rows + :param sample_col_n: number of features to subset; default keeps all columns + :param permute_col: boolean; whether or not to permute the columns + :param signal_features: list of features to use as signal features + :param n_signal_features: number of signal features; required if permute_nonsignal_col is not None + :param permute_nonsignal_col: how to permute the nonsignal features; must be one of + [None, "block", "indep", "augment"], where None performs no permutation, "block" performs the permutation + row-wise, "indep" permutes each nonsignal feature column independently, "augment" augments the signal features + with the row-permuted X matrix.
+ :return: + """ + assert permute_nonsignal_col in [None, "block", "indep", "augment"] + if X is None: + X = pd.read_csv(fpath) + if normalize: + X = (X - X.mean()) / X.std() + if seed is not None: + np.random.seed(seed) + if permute_col: + X = X[np.random.permutation(X.columns)] + if sample_row_n is not None: + keep_idx = np.random.choice(X.shape[0], sample_row_n, replace=False) + X = X.iloc[keep_idx, :] + if sample_col_n is not None: + if signal_features is None: + X = X.sample(n=sample_col_n, replace=False, axis=1) + else: + rand_features = np.random.choice([col for col in X.columns if col not in signal_features], + sample_col_n - len(signal_features), replace=False) + X = X[signal_features + list(rand_features)] + if signal_features is not None: + X = X[signal_features + [col for col in X.columns if col not in signal_features]] + if permute_nonsignal_col is not None: + assert n_signal_features is not None + if permute_nonsignal_col == "block": + X = np.hstack([X.iloc[:, :n_signal_features].to_numpy(), + X.iloc[np.random.permutation(X.shape[0]), n_signal_features:].to_numpy()]) + X = pd.DataFrame(X) + elif permute_nonsignal_col == "indep": + for j in range(n_signal_features, X.shape[1]): + X.iloc[:, j] = np.random.permutation(X.iloc[:, j]) + elif permute_nonsignal_col == "augment": + X = np.hstack([X.iloc[:, :n_signal_features].to_numpy(), + X.iloc[np.random.permutation(X.shape[0]), :].to_numpy()]) + X = IndexedArray(pd.DataFrame(X).to_numpy(), index=keep_idx) + return X + return X.to_numpy() + + +def sample_normal_X(n, d, mean=0, scale=1, corr=0, Sigma=None): + """ + Sample X with iid normal entries + :param n: + :param d: + :param mean: + :param scale: + :param corr: + :param Sigma: + :return: + """ + if Sigma is not None: + if np.isscalar(mean): + mean = np.repeat(mean, d) + X = np.random.multivariate_normal(mean, Sigma, size=n) + elif corr == 0: + X = np.random.normal(mean, scale, size=(n, d)) + else: + Sigma = np.zeros((d, d)) + corr + np.fill_diagonal(Sigma, 1) + if np.isscalar(mean): + mean = np.repeat(mean, d) + X = np.random.multivariate_normal(mean, Sigma, size=n) + return X + + +def sample_block_cor_X(n, d, rho, n_blocks, mean=0): + """ + Sample X from N(mean, Sigma) where Sigma is a block diagonal covariance matrix + :param n: number of samples + :param d: number of features + :param rho: correlation or vector of correlations + :param n_blocks: number of blocks + :param mean: mean of normal distribution + :return: + """ + Sigma = np.zeros((d, d)) + block_size = d // n_blocks + if np.isscalar(rho): + rho = np.repeat(rho, n_blocks) + for i in range(n_blocks): + start = i * block_size + end = (i + 1) * block_size + if i == (n_blocks - 1): + end = d + Sigma[start:end, start:end] = rho[i] + np.fill_diagonal(Sigma, 1) + X = sample_normal_X(n=n, d=d, mean=mean, Sigma=Sigma) + return X + + +def sample_X(support, X_fun, **kwargs): + """ + Wrapper around dgp function for X that reorders columns so support features are in front + :param support: + :param X_fun: + :param kwargs: + :return: + """ + X = X_fun(**kwargs) + for i in range(X.shape[1]): + if i not in support: + support.append(i) + X[:] = X[:, support] + return X + + +def generate_coef(beta, s): + if isinstance(beta, int) or isinstance(beta, float): + beta = np.repeat(beta, repeats=s) + return beta + + +def corrupt_leverage(x_train, y_train, mean_shift, corrupt_quantile, mode="normal"): + assert mode in ["normal", "constant"] + if mean_shift is None: + return y_train + ranked_rows = np.apply_along_axis(np.linalg.norm, axis=1, 
arr=x_train).argsort().argsort() + low_idx = np.where(ranked_rows < round(corrupt_quantile * len(y_train)))[0] + hi_idx = np.where(ranked_rows >= (len(y_train) - round(corrupt_quantile * len(y_train))))[0] + if mode == "normal": + hi_corrupted = np.random.normal(mean_shift, 1, size=len(hi_idx)) + low_corrupted = np.random.normal(-mean_shift, 1, size=len(low_idx)) + elif mode == "constant": + hi_corrupted = mean_shift + low_corrupted = -mean_shift + y_train[hi_idx] = hi_corrupted + y_train[low_idx] = low_corrupted + return y_train + + +def linear_model(X, sigma, s, beta, heritability=None, snr=None, error_fun=None, + frac_corrupt=None, corrupt_how='permute', corrupt_size=None, + corrupt_mean=None, return_support=False): + """ + This method is used to crete responses from a linear model with hard sparsity + Parameters: + X: X matrix + s: sparsity + beta: coefficient vector. If beta not a vector, then assumed a constant + sigma: s.d. of added noise + Returns: + numpy array of shape (n) + """ + n, p = X.shape + def create_y(x, s, beta): + linear_term = 0 + for j in range(s): + linear_term += x[j] * beta[j] + return linear_term + + beta = generate_coef(beta, s) + y_train = np.array([create_y(X[i, :], s, beta) for i in range(len(X))]) + if heritability is not None: + sigma = (np.var(y_train) * ((1.0 - heritability) / heritability)) ** 0.5 + if snr is not None: + sigma = (np.var(y_train) / snr) ** 0.5 + if error_fun is None: + error_fun = np.random.randn + if frac_corrupt is None and corrupt_size is None: + y_train = y_train + sigma * error_fun(n) + else: + if frac_corrupt is None: + frac_corrupt = 0 + num_corrupt = int(np.floor(frac_corrupt*len(y_train))) + corrupt_indices = random.sample([*range(len(y_train))], k=num_corrupt) + if corrupt_how == 'permute': + corrupt_array = y_train[corrupt_indices] + corrupt_array = random.sample(list(corrupt_array), len(corrupt_array)) + for i,index in enumerate(corrupt_indices): + y_train[index] = corrupt_array[i] + y_train = y_train + sigma * error_fun(n) + elif corrupt_how == 'cauchy': + for i in range(len(y_train)): + if i in corrupt_indices: + y_train[i] = y_train[i] + sigma*np.random.standard_cauchy() + else: + y_train[i] = y_train[i] + sigma*error_fun() + elif corrupt_how == "leverage_constant": + if isinstance(corrupt_size, int): + corrupt_quantile = corrupt_size / n + else: + corrupt_quantile = corrupt_size + y_train = y_train + sigma * error_fun(n) + corrupt_idx = np.random.choice(range(s, p), size=1) + y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="constant") + elif corrupt_how == "leverage_normal": + if isinstance(corrupt_size, int): + corrupt_quantile = corrupt_size / n + else: + corrupt_quantile = corrupt_size + y_train = y_train + sigma * error_fun(n) + corrupt_idx = np.random.choice(range(s, p), size=1) + y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="normal") + + if return_support: + support = np.concatenate((np.ones(s), np.zeros(X.shape[1] - s))) + return y_train, support, beta + else: + return y_train + + +def lss_model(X, sigma, m, r, tau, beta, heritability=None, snr=None, error_fun=None, min_active=None, + frac_corrupt=None, corrupt_how='permute', corrupt_size=None, corrupt_mean=None, + return_support=False): + """ + This method creates response from an LSS model + + X: data matrix + m: number of interaction terms + r: max order of interaction + tau: threshold + sigma: standard deviation of 
noise + beta: coefficient vector. If beta not a vector, then assumed a constant + + :return + y_train: numpy array of shape (n) + """ + n, p = X.shape + assert p >= m * r # Cannot have more interactions * size than the dimension + + def lss_func(x, beta): + x_bool = (x - tau) > 0 + y = 0 + for j in range(m): + lss_term_components = x_bool[j * r:j * r + r] + lss_term = int(all(lss_term_components)) + y += lss_term * beta[j] + return y + + def lss_vector_fun(x, beta): + x_bool = (x - tau) > 0 + y = 0 + max_iter = 100 + features = np.arange(p) + support_idx = [] + for j in range(m): + cnt = 0 + while True: + int_features = np.random.choice(features, size=r, replace=False) + lss_term_components = x_bool[:, int_features] + lss_term = np.apply_along_axis(all, 1, lss_term_components) + cnt += 1 + if np.mean(lss_term) >= min_active or cnt > max_iter: + y += lss_term * beta[j] + features = list(set(features).difference(set(int_features))) + support_idx.append(int_features) + if cnt > max_iter: + warnings.warn("Could not find interaction {} with min active >= {}".format(j, min_active)) + break + support_idx = np.stack(support_idx).ravel() + support = np.zeros(p) + for j in support_idx: + support[j] = 1 + return y, support + + beta = generate_coef(beta, m) + if tau == 'median': + tau = np.median(X,axis = 0) + + if min_active is None: + y_train = np.array([lss_func(X[i, :], beta) for i in range(n)]) + support = np.concatenate((np.ones(m * r), np.zeros(X.shape[1] - (m * r)))) + else: + y_train, support = lss_vector_fun(X, beta) + + if heritability is not None: + sigma = (np.var(y_train) * ((1.0 - heritability) / heritability)) ** 0.5 + if snr is not None: + sigma = (np.var(y_train) / snr) ** 0.5 + if error_fun is None: + error_fun = np.random.randn + + if frac_corrupt is None and corrupt_size is None: + y_train = y_train + sigma * error_fun(n) + else: + if frac_corrupt is None: + frac_corrupt = 0 + num_corrupt = int(np.floor(frac_corrupt*len(y_train))) + corrupt_indices = random.sample([*range(len(y_train))], k=num_corrupt) + if corrupt_how == 'permute': + corrupt_array = y_train[corrupt_indices] + corrupt_array = random.sample(list(corrupt_array), len(corrupt_array)) + for i,index in enumerate(corrupt_indices): + y_train[index] = corrupt_array[i] + y_train = y_train + sigma * error_fun(n) + elif corrupt_how == 'cauchy': + for i in range(len(y_train)): + if i in corrupt_indices: + y_train[i] = y_train[i] + sigma*np.random.standard_cauchy() + else: + y_train[i] = y_train[i] + sigma*error_fun() + elif corrupt_how == "leverage_constant": + if isinstance(corrupt_size, int): + corrupt_quantile = corrupt_size / n + else: + corrupt_quantile = corrupt_size + y_train = y_train + sigma * error_fun(n) + corrupt_idx = np.random.choice(range(m*r, p), size=1) + y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="constant") + elif corrupt_how == "leverage_normal": + if isinstance(corrupt_size, int): + corrupt_quantile = corrupt_size / n + else: + corrupt_quantile = corrupt_size + y_train = y_train + sigma * error_fun(n) + corrupt_idx = np.random.choice(range(m*r, p), size=1) + y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="normal") + + if return_support: + return y_train, support, beta + else: + return y_train + + +def partial_linear_lss_model(X, sigma, s, m, r, tau, beta, heritability=None, snr=None, error_fun=None, + min_active=None, frac_corrupt=None, corrupt_how='permute', 
corrupt_size=None, + corrupt_mean=None, diagnostics=False, return_support=False): + """ + This method creates response from an linear + lss model + + X: data matrix + m: number of interaction terms + r: max order of interaction + s: denotes number of linear terms in EACH interaction term + tau: threshold + sigma: standard deviation of noise + beta: coefficient vector. If beta not a vector, then assumed a constant + + :return + y_train: numpy array of shape (n) + """ + n, p = X.shape + assert p >= m * r # Cannot have more interactions * size than the dimension + assert s <= r + + def partial_linear_func(x,s,beta): + y = 0.0 + count = 0 + for j in range(m): + for i in range(s): + y += beta[count]*x[j*r+i] + count += 1 + return y + + + def lss_func(x, beta): + x_bool = (x - tau) > 0 + y = 0 + for j in range(m): + lss_term_components = x_bool[j * r:j * r + r] + lss_term = int(all(lss_term_components)) + y += lss_term * beta[j] + return y + + def lss_vector_fun(x, beta, beta_linear): + x_bool = (x - tau) > 0 + y = 0 + max_iter = 100 + features = np.arange(p) + support_idx = [] + for j in range(m): + cnt = 0 + while True: + int_features = np.concatenate( + [np.arange(j*r, j*r+s), np.random.choice(features, size=r-s, replace=False)] + ) + lss_term_components = x_bool[:, int_features] + lss_term = np.apply_along_axis(all, 1, lss_term_components) + cnt += 1 + if np.mean(lss_term) >= min_active or cnt > max_iter: + norm_constant = sum(np.var(x[:, (j*r):(j*r+s)], axis=0) * beta_linear[(j*s):((j+1)*s)]**2) + relative_beta = beta[j] / sum(beta_linear[(j*s):((j+1)*s)]) + y += lss_term * relative_beta * np.sqrt(norm_constant) / np.std(lss_term) + features = list(set(features).difference(set(int_features))) + support_idx.append(int_features) + if cnt > max_iter: + warnings.warn("Could not find interaction {} with min active >= {}".format(j, min_active)) + break + support_idx = np.stack(support_idx).ravel() + support = np.zeros(p) + for j in support_idx: + support[j] = 1 + return y, support + + beta_lss = generate_coef(beta, m) + beta_linear = generate_coef(beta, s*m) + if tau == 'median': + tau = np.median(X,axis = 0) + + y_train_linear = np.array([partial_linear_func(X[i, :],s,beta_linear ) for i in range(n)]) + if min_active is None: + y_train_lss = np.array([lss_func(X[i, :], beta_lss) for i in range(n)]) + support = np.concatenate((np.ones(max(m * r, s)), np.zeros(X.shape[1] - max((m * r), s)))) + else: + y_train_lss, support = lss_vector_fun(X, beta_lss, beta_linear) + y_train = np.array([y_train_linear[i] + y_train_lss[i] for i in range(n)]) + if heritability is not None: + sigma = (np.var(y_train) * ((1.0 - heritability) / heritability)) ** 0.5 + if snr is not None: + sigma = (np.var(y_train) / snr) ** 0.5 + if error_fun is None: + error_fun = np.random.randn + + if frac_corrupt is None and corrupt_size is None: + y_train = y_train + sigma * error_fun(n) + else: + if frac_corrupt is None: + frac_corrupt = 0 + num_corrupt = int(np.floor(frac_corrupt*len(y_train))) + corrupt_indices = random.sample([*range(len(y_train))], k=num_corrupt) + if corrupt_how == 'permute': + corrupt_array = y_train[corrupt_indices] + corrupt_array = random.sample(list(corrupt_array), len(corrupt_array)) + for i,index in enumerate(corrupt_indices): + y_train[index] = corrupt_array[i] + y_train = y_train + sigma * error_fun(n) + elif corrupt_how == 'cauchy': + for i in range(len(y_train)): + if i in corrupt_indices: + y_train[i] = y_train[i] + sigma*np.random.standard_cauchy() + else: + y_train[i] = y_train[i] + 
sigma*error_fun() + elif corrupt_how == "leverage_constant": + if isinstance(corrupt_size, int): + corrupt_quantile = corrupt_size / n + else: + corrupt_quantile = corrupt_size + y_train = y_train + sigma * error_fun(n) + corrupt_idx = np.random.choice(range(max(m*r, s), p), size=1) + y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="constant") + elif corrupt_how == "leverage_normal": + if isinstance(corrupt_size, int): + corrupt_quantile = corrupt_size / n + else: + corrupt_quantile = corrupt_size + y_train = y_train + sigma * error_fun(n) + corrupt_idx = np.random.choice(range(max(m*r, s), p), size=1) + y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="normal") + + if return_support: + return y_train, support, beta_lss + elif diagnostics: + return y_train, y_train_linear, y_train_lss + else: + return y_train + + +def hierarchical_poly(X, sigma=None, m=1, r=1, beta=1, heritability=None, snr=None, + frac_corrupt=None, corrupt_how='permute', corrupt_size=None, + corrupt_mean=None, error_fun=None, return_support=False): + """ + This method creates responses from a hierarchical polynomial model + + X: data matrix + m: number of interaction terms + r: max order of interaction + sigma: standard deviation of noise + beta: coefficient vector. If beta not a vector, then assumed a constant + + :return + y_train: numpy array of shape (n) + """ + + n, p = X.shape + assert p >= m * r + + def reg_func(x, beta): + y = 0 + for i in range(m): + hier_term = 1.0 + for j in range(r): + hier_term += x[i * r + j] * hier_term + y += hier_term * beta[i] + return y + + beta = generate_coef(beta, m) + y_train = np.array([reg_func(X[i, :], beta) for i in range(n)]) + if heritability is not None: + sigma = (np.var(y_train) * ((1.0 - heritability) / heritability)) ** 0.5 + if snr is not None: + sigma = (np.var(y_train) / snr) ** 0.5 + if error_fun is None: + error_fun = np.random.randn + + if frac_corrupt is None and corrupt_size is None: + y_train = y_train + sigma * error_fun(n) + else: + if frac_corrupt is None: + frac_corrupt = 0 + num_corrupt = int(np.floor(frac_corrupt*len(y_train))) + corrupt_indices = random.sample([*range(len(y_train))], k=num_corrupt) + if corrupt_how == 'permute': + corrupt_array = y_train[corrupt_indices] + corrupt_array = random.sample(list(corrupt_array), len(corrupt_array)) + for i,index in enumerate(corrupt_indices): + y_train[index] = corrupt_array[i] + y_train = y_train + sigma * error_fun(n) + elif corrupt_how == 'cauchy': + for i in range(len(y_train)): + if i in corrupt_indices: + y_train[i] = y_train[i] + sigma*np.random.standard_cauchy() + else: + y_train[i] = y_train[i] + sigma*error_fun() + elif corrupt_how == "leverage_constant": + if isinstance(corrupt_size, int): + corrupt_quantile = corrupt_size / n + else: + corrupt_quantile = corrupt_size + y_train = y_train + sigma * error_fun(n) + corrupt_idx = np.random.choice(range(m*r, p), size=1) + y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="constant") + elif corrupt_how == "leverage_normal": + if isinstance(corrupt_size, int): + corrupt_quantile = corrupt_size / n + else: + corrupt_quantile = corrupt_size + y_train = y_train + sigma * error_fun(n) + corrupt_idx = np.random.choice(range(m*r, p), size=1) + y_train = corrupt_leverage(X[:, corrupt_idx], y_train, mean_shift=corrupt_mean, corrupt_quantile=corrupt_quantile, mode="normal") + + if return_support: + support = np.concatenate((np.ones(m * r), np.zeros(X.shape[1] - (m * r)))) + return y_train, support, beta + else: + return y_train + + +def logistic_model(X, s, beta=None, beta_grid=np.logspace(-4, 4, 100), heritability=None, + frac_label_corruption=None, return_support=False): + """ + This method is used to create responses from a logistic regression model with hard sparsity + Parameters: + X: X matrix + s: sparsity + beta: coefficient vector. If beta not a vector, then assumed a constant + Returns: + numpy array of shape (n) + """ + + def create_prob(x, beta): + linear_term = 0 + for j in range(len(beta)): + linear_term += x[j] * beta[j] + prob = 1 / (1 + np.exp(-linear_term)) + return prob + + def create_y(x, beta): + linear_term = 0 + for j in range(len(beta)): + linear_term += x[j] * beta[j] + prob = 1 / (1 + np.exp(-linear_term)) + return (np.random.uniform(size=1) < prob) * 1 + + if heritability is None: + beta = generate_coef(beta, s) + y_train = np.array([create_y(X[i, :], beta) for i in range(len(X))]).ravel() + else: + # find beta to get desired heritability via adaptive grid search within eps=0.01 + y_train, beta, heritability, hdict = logistic_heritability_search(X, heritability, s, create_prob, beta_grid) + + if frac_label_corruption is None: + y_train = y_train + else: + corrupt_indices = np.random.choice(np.arange(len(y_train)), size=math.ceil(frac_label_corruption*len(y_train))) + y_train[corrupt_indices] = 1 - y_train[corrupt_indices] + if return_support: + support = np.concatenate((np.ones(s), np.zeros(X.shape[1] - s))) + return y_train, support, beta + else: + return y_train + + +def logistic_lss_model(X, m, r, tau, beta=None, heritability=None, beta_grid=np.logspace(-4, 4, 100), + min_active=None, frac_label_corruption=None, return_support=False): + """ + This method is used to create responses from a logistic LSS model + X: X matrix + m: number of interaction terms + r: max order of interaction + tau: threshold + beta: coefficient vector.
If beta not a vector, then assumed a constant + Returns: + numpy array of shape (n) + """ + n, p = X.shape + + def lss_prob_func(x, beta): + x_bool = (x - tau) > 0 + y = 0 + for j in range(m): + lss_term_components = x_bool[j * r:j * r + r] + lss_term = int(all(lss_term_components)) + y += lss_term * beta[j] + prob = 1 / (1 + np.exp(-y)) + return prob + + def lss_func(x, beta): + x_bool = (x - tau) > 0 + y = 0 + for j in range(m): + lss_term_components = x_bool[j * r:j * r + r] + lss_term = int(all(lss_term_components)) + y += lss_term * beta[j] + prob = 1 / (1 + np.exp(-y)) + return (np.random.uniform(size=1) < prob) * 1 + + def lss_vector_fun(x, beta): + x_bool = (x - tau) > 0 + y = 0 + max_iter = 100 + features = np.arange(p) + support_idx = [] + for j in range(m): + cnt = 0 + while True: + int_features = np.random.choice(features, size=r, replace=False) + lss_term_components = x_bool[:, int_features] + lss_term = np.apply_along_axis(all, 1, lss_term_components) + cnt += 1 + if np.mean(lss_term) >= min_active or cnt > max_iter: + y += lss_term * beta[j] + features = list(set(features).difference(set(int_features))) + support_idx.append(int_features) + if cnt > max_iter: + warnings.warn("Could not find interaction {} with min active >= {}".format(j, min_active)) + break + prob = 1 / (1 + np.exp(-y)) + y = (np.random.uniform(size=n) < prob) * 1 + support_idx = np.stack(support_idx).ravel() + support = np.zeros(p) + for j in support_idx: + support[j] = 1 + return y, support + + if tau == 'median': + tau = np.median(X,axis = 0) + + if heritability is None: + beta = generate_coef(beta, m) + if min_active is None: + y_train = np.array([lss_func(X[i, :], beta) for i in range(n)]).ravel() + support = np.concatenate((np.ones(m * r), np.zeros(X.shape[1] - (m * r)))) + else: + y_train, support = lss_vector_fun(X, beta) + y_train = y_train.ravel() + else: + if min_active is not None: + raise ValueError("Cannot set heritability and min_active at the same time.") + # find beta to get desired heritability via adaptive grid search within eps=0.01 (need to jitter beta to reach higher signals) + y_train, beta, heritability, hdict = logistic_heritability_search(X, heritability, m, lss_prob_func, beta_grid, jitter_beta=True) + support = np.concatenate((np.ones(m * r), np.zeros(X.shape[1] - (m * r)))) + + if frac_label_corruption is None: + y_train = y_train + else: + corrupt_indices = np.random.choice(np.arange(len(y_train)), size=math.ceil(frac_label_corruption*len(y_train))) + y_train[corrupt_indices] = 1 - y_train[corrupt_indices] + + if return_support: + return y_train, support, beta + else: + return y_train + + +def logistic_partial_linear_lss_model(X, s, m, r, tau, beta=None, heritability=None, beta_grid=np.logspace(-4, 4, 100), + min_active=None, frac_label_corruption=None, return_support=False): + """ + This method is used to create responses from a logistic model model with lss + X: X matrix + s: sparsity + beta: coefficient vector. 
If beta not a vector, then assumed a constant + Returns: + numpy array of shape (n) + """ + n, p = X.shape + assert p >= m * r + + def partial_linear_func(x,s,beta): + y = 0.0 + count = 0 + for j in range(m): + for i in range(s): + y += beta[count]*x[j*r+i] + count += 1 + return y + + def lss_func(x, beta): + x_bool = (x - tau) > 0 + y = 0 + for j in range(m): + lss_term_components = x_bool[j * r:j * r + r] + lss_term = int(all(lss_term_components)) + y += lss_term * beta[j] + return y + + def logistic_link_func(y): + prob = 1 / (1 + np.exp(-y)) + return (np.random.uniform(size=1) < prob) * 1 + + def logistic_prob_func(y): + prob = 1 / (1 + np.exp(-y)) + return prob + + def lss_vector_fun(x, beta, beta_linear): + x_bool = (x - tau) > 0 + y = 0 + max_iter = 100 + features = np.arange(p) + support_idx = [] + for j in range(m): + cnt = 0 + while True: + int_features = np.concatenate( + [np.arange(j*r, j*r+s), np.random.choice(features, size=r-s, replace=False)] + ) + lss_term_components = x_bool[:, int_features] + lss_term = np.apply_along_axis(all, 1, lss_term_components) + cnt += 1 + if np.mean(lss_term) >= min_active or cnt > max_iter: + norm_constant = sum(np.var(x[:, (j*r):(j*r+s)], axis=0) * beta_linear[(j*s):((j+1)*s)]**2) + relative_beta = beta[j] / sum(beta_linear[(j*s):((j+1)*s)]) + y += lss_term * relative_beta * np.sqrt(norm_constant) / np.std(lss_term) + features = list(set(features).difference(set(int_features))) + support_idx.append(int_features) + if cnt > max_iter: + warnings.warn("Could not find interaction {} with min active >= {}".format(j, min_active)) + break + support_idx = np.stack(support_idx).ravel() + support = np.zeros(p) + for j in support_idx: + support[j] = 1 + return y, support + + if tau == 'median': + tau = np.median(X,axis = 0) + + if heritability is None: + beta_lss = generate_coef(beta, m) + beta_linear = generate_coef(beta, s*m) + + y_train_linear = np.array([partial_linear_func(X[i, :],s,beta_linear ) for i in range(n)]) + if min_active is None: + y_train_lss = np.array([lss_func(X[i, :], beta_lss) for i in range(n)]) + support = np.concatenate((np.ones(m * r), np.zeros(X.shape[1] - (m * r)))) + else: + y_train_lss, support = lss_vector_fun(X, beta_lss, beta_linear) + y_train = np.array([y_train_linear[i] + y_train_lss[i] for i in range(n)]) + y_train = np.array([logistic_link_func(y_train[i]) for i in range(n)]) + else: + if min_active is not None: + raise ValueError("Cannot set heritability and min_active at the same time.") + # find beta to get desired heritability via adaptive grid search within eps=0.01 + eps = 0.01 + max_iter = 1000 + pves = {} + for idx, beta in enumerate(beta_grid): + beta_lss_vec = generate_coef(beta, m) + beta_linear_vec = generate_coef(beta, s*m) + + y_train_linear = np.array([partial_linear_func(X[i, :], s, beta_linear_vec) for i in range(n)]) + y_train_lss = np.array([lss_func(X[i, :], beta_lss_vec) for i in range(n)]) + y_train_sum = np.array([y_train_linear[i] + y_train_lss[i] for i in range(n)]) + prob_train = np.array([logistic_prob_func(y_train_sum[i]) for i in range(n)]).ravel() + np.random.seed(idx) + y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1 + pve = np.var(prob_train) / np.var(y_train) + pves[(idx, beta)] = pve + + (idx, beta), pve = min(pves.items(), key=lambda x: abs(x[1] - heritability)) + beta_lss_vec = generate_coef(beta, m) + beta_linear_vec = generate_coef(beta, s*m) + + y_train_linear = np.array([partial_linear_func(X[i, :], s, beta_linear_vec) for i in range(n)]) + y_train_lss = 
np.array([lss_func(X[i, :], beta_lss_vec) for i in range(n)]) + y_train_sum = np.array([y_train_linear[i] + y_train_lss[i] for i in range(n)]) + + prob_train = np.array([logistic_prob_func(y_train_sum[i]) for i in range(n)]).ravel() + np.random.seed(idx) + y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1 + if pve > heritability: + min_beta = beta_grid[idx-1] + max_beta = beta + else: + min_beta = beta + max_beta = beta_grid[idx+1] + cur_beta = (min_beta + max_beta) / 2 + iter = 1 + while np.abs(pve - heritability) > eps: + beta_lss_vec = generate_coef(cur_beta, m) + beta_linear_vec = generate_coef(cur_beta, s*m) + + y_train_linear = np.array([partial_linear_func(X[i, :], s, beta_linear_vec) for i in range(n)]) + y_train_lss = np.array([lss_func(X[i, :], beta_lss_vec) for i in range(n)]) + y_train_sum = np.array([y_train_linear[i] + y_train_lss[i] for i in range(n)]) + + prob_train = np.array([logistic_prob_func(y_train_sum[i]) for i in range(n)]).ravel() + np.random.seed(iter + len(beta_grid)) + y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1 + pve = np.var(prob_train) / np.var(y_train) + pves[(iter + len(beta_grid), cur_beta)] = pve + if pve > heritability: + max_beta = cur_beta + else: + min_beta = cur_beta + beta = cur_beta + cur_beta = (min_beta + max_beta) / 2 + iter += 1 + if iter > max_iter: + (idx, cur_beta), pve = min(pves.items(), key=lambda x: abs(x[1] - heritability)) + beta_lss_vec = generate_coef(cur_beta, m) + beta_linear_vec = generate_coef(cur_beta, s*m) + + y_train_linear = np.array([partial_linear_func(X[i, :], s, beta_linear_vec) for i in range(n)]) + y_train_lss = np.array([lss_func(X[i, :], beta_lss_vec) for i in range(n)]) + y_train_sum = np.array([y_train_linear[i] + y_train_lss[i] for i in range(n)]) + + prob_train = np.array([logistic_prob_func(y_train_sum[i]) for i in range(n)]).ravel() + np.random.seed(idx) + y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1 + pve = np.var(prob_train) / np.var(y_train) + beta = cur_beta + break + support = np.concatenate((np.ones(m * r), np.zeros(X.shape[1] - (m * r)))) + + if frac_label_corruption is None: + y_train = y_train + else: + corrupt_indices = np.random.choice(np.arange(len(y_train)), size=math.ceil(frac_label_corruption*len(y_train))) + y_train[corrupt_indices] = 1 - y_train[corrupt_indices] + + y_train = y_train.ravel() + + if return_support: + return y_train, support, beta + else: + return y_train + + +def logistic_hier_model(X, m, r, beta=None, heritability=None, beta_grid=np.logspace(-4, 4, 100), + frac_label_corruption=None, return_support=False): + + n, p = X.shape + assert p >= m * r + + def reg_func(x, beta): + y = 0 + for i in range(m): + hier_term = 1.0 + for j in range(r): + hier_term += x[i * r + j] * hier_term + y += hier_term * beta[i] + return y + + def logistic_link_func(y): + prob = 1 / (1 + np.exp(-y)) + return (np.random.uniform(size=1) < prob) * 1 + + def prob_func(x, beta): + y = 0 + for i in range(m): + hier_term = 1.0 + for j in range(r): + hier_term += x[i * r + j] * hier_term + y += hier_term * beta[i] + return 1 / (1 + np.exp(-y)) + + if heritability is None: + beta = generate_coef(beta, m) + y_train = np.array([reg_func(X[i, :], beta) for i in range(n)]) + y_train = np.array([logistic_link_func(y_train[i]) for i in range(n)]) + else: + # find beta to get desired heritability via adaptive grid search within eps=0.01 + y_train, beta, heritability, hdict = logistic_heritability_search(X, heritability, m, prob_func, beta_grid) + + if 
frac_label_corruption is None: + y_train = y_train + else: + corrupt_indices = np.random.choice(np.arange(len(y_train)), size=math.ceil(frac_label_corruption*len(y_train))) + y_train[corrupt_indices] = 1 - y_train[corrupt_indices] + y_train = y_train.ravel() + + if return_support: + support = np.concatenate((np.ones(m * r), np.zeros(X.shape[1] - (m * r)))) + return y_train, support, beta + else: + return y_train + + +def logistic_heritability_search(X, heritability, s, prob_fun, beta_grid=np.logspace(-4, 4, 100), + eps=0.01, max_iter=1000, jitter_beta=False, return_pve=True): + pves = {} + + # first search over beta grid + for idx, beta in enumerate(beta_grid): + np.random.seed(idx) + beta_vec = generate_coef(beta, s) + if jitter_beta: + beta_vec = beta_vec + np.random.uniform(-1e-4, 1e-4, beta_vec.shape) + prob_train = np.array([prob_fun(X[i, :], beta_vec) for i in range(len(X))]).ravel() + y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1 + pve = np.var(prob_train) / np.var(y_train) + pves[(idx, beta)] = pve + + # find beta with heritability closest to desired heritability + (idx, beta), pve = min(pves.items(), key=lambda x: abs(x[1] - heritability)) + np.random.seed(idx) + beta_vec = generate_coef(beta, s) + if jitter_beta: + beta_vec = beta_vec + np.random.uniform(-1e-4, 1e-4, beta_vec.shape) + prob_train = np.array([prob_fun(X[i, :], beta_vec) for i in range(len(X))]).ravel() + y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1 + + # search nearby beta to get closer to desired heritability + if pve > heritability: + min_beta = beta_grid[idx-1] + max_beta = beta + else: + min_beta = beta + max_beta = beta_grid[idx+1] + cur_beta = (min_beta + max_beta) / 2 + iter = 1 + while np.abs(pve - heritability) > eps: + np.random.seed(iter + len(beta_grid)) + beta_vec = generate_coef(cur_beta, s) + if jitter_beta: + beta_vec = beta_vec + np.random.uniform(-1e-4, 1e-4, beta_vec.shape) + prob_train = np.array([prob_fun(X[i, :], beta_vec) for i in range(len(X))]).ravel() + y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1 + pve = np.var(prob_train) / np.var(y_train) + pves[(iter + len(beta_grid), cur_beta)] = pve + if pve > heritability: + max_beta = cur_beta + else: + min_beta = cur_beta + cur_beta = (min_beta + max_beta) / 2 + beta = beta_vec + iter += 1 + if iter > max_iter: + (idx, cur_beta), pve = min(pves.items(), key=lambda x: abs(x[1] - heritability)) + np.random.seed(idx) + beta_vec = generate_coef(cur_beta, s) + if jitter_beta: + beta_vec = beta_vec + np.random.uniform(-1e-4, 1e-4, beta_vec.shape) + prob_train = np.array([prob_fun(X[i, :], beta_vec) for i in range(len(X))]).ravel() + y_train = (np.random.uniform(size=len(prob_train)) < prob_train) * 1 + pve = np.var(prob_train) / np.var(y_train) + beta = beta_vec + break + + if return_pve: + return y_train, beta, pve, pves + else: + return y_train, beta + + +def entropy_X(n, scale=False): + x1 = np.random.choice([0, 1], (n, 1), replace=True) + x2 = np.random.normal(0, 1, (n, 1)) + x3 = np.random.choice(np.arange(4), (n, 1), replace=True) + x4 = np.random.choice(np.arange(10), (n, 1), replace=True) + x5 = np.random.choice(np.arange(20), (n, 1), replace=True) + X = np.concatenate((x1, x2, x3, x4, x5), axis=1) + if scale: + X = (X - X.mean()) / X.std() + return X + + +def entropy_y(X, c=3, return_support=False): + if any(X[:, 0] < 0): + x = (X[:, 0] > 0) * 1 + else: + x = X[:, 0] + prob = ((c - 2) * x + 1) / c + y = (np.random.uniform(size=len(prob)) < prob) * 1 + if return_support: + 
support = np.array([0, 1, 0, 0, 0]) + beta = None + return y, support, beta + else: + return y + + +class IndexedArray(np.ndarray): + def __new__(cls, input_array, index=None): + obj = np.asarray(input_array).view(cls) + obj.index = index + return obj + + def __array_finalize__(self, obj): + if obj is None: + return + self.index = getattr(obj, 'index', None) + +#%% diff --git a/feature_importance/scripts/viz.R b/feature_importance/scripts/viz.R new file mode 100644 index 0000000..c61cee0 --- /dev/null +++ b/feature_importance/scripts/viz.R @@ -0,0 +1,781 @@ +library(magrittr) + + +# reformat results +reformat_results <- function(results, prediction = FALSE) { + if (!prediction) { + results_grouped <- results %>% + dplyr::group_by(index) %>% + tidyr::nest(fi_scores = var:(tidyselect::last_col())) %>% + dplyr::ungroup() %>% + dplyr::select(-index) %>% + # join fi+model to get method column + tidyr::unite(col = "method", fi, model, na.rm = TRUE, remove = FALSE) %>% + dplyr::mutate( + # get rid of duplicate RF in r2f method name + method = ifelse(stringr::str_detect(method, "^r2f.*RF$"), + stringr::str_remove(method, "\\_RF$"), method), + # unnest columns + prediction_score = purrr::map_dbl( + fi_scores, + function(x) { + ifelse("prediction_score" %in% colnames(x), x$prediction_score[[1]], NA) + } + ), + tauAP = purrr::map_dbl( + fi_scores, + function(x) { + ifelse("tauAP" %in% colnames(x), x$tauAP[[1]], NA) + } + ), + RBO = purrr::map_dbl( + fi_scores, + function(x) { + ifelse("RBO" %in% colnames(x), x$RBO[[1]], NA) + } + ) + ) + } else { + results_grouped <- results %>% + dplyr::group_by(index) %>% + tidyr::nest(predictions = sample_id:(tidyselect::last_col())) %>% + dplyr::ungroup() %>% + dplyr::select(-index) %>% + dplyr::mutate( + # get rid of duplicate RF in r2f method name + method = ifelse(stringr::str_detect(model, "^r2f.*RF$"), + stringr::str_remove(model, "\\_RF$"), model) + ) + } + return(results_grouped) +} + +# plot metrics (mean value across repetitions with error bars) +plot_metrics <- function(results, + metric = c("rocauc", "prauc"),# "tpr", + #"median_signal_rank", "max_signal_rank"), + x_str, facet_str, linetype_str = NULL, + point_size = 1, line_size = 1, errbar_width = 0, + alpha = 0.5, inside_legend = FALSE, + manual_color_palette = NULL, + show_methods = NULL, + method_labels = ggplot2::waiver(), + alpha_values = NULL, + legend_position = NULL, + custom_theme = vthemes::theme_vmodern(size_preset = "medium")) { + if (is.null(show_methods)) { + show_methods <- sort(unique(results$method)) + } + metric_names <- metric + plt_df <- results %>% + dplyr::select(rep, method, + tidyselect::all_of(c(metric, x_str, facet_str, linetype_str))) %>% + tidyr::pivot_longer( + cols = tidyselect::all_of(metric), names_to = "metric" + ) %>% + dplyr::group_by( + method, metric, dplyr::across(tidyselect::all_of(c(x_str, facet_str, linetype_str))) + ) %>% + dplyr::summarise(mean = mean(value), + sd = sd(value) / sqrt(dplyr::n()), + .groups = "keep") %>% + dplyr::filter(method %in% show_methods) %>% + dplyr::mutate( + method = factor(method, levels = show_methods), + metric = forcats::fct_recode( + factor(metric, levels = metric_names), + AUROC = "rocauc", AUPRC = "prauc", Accuracy = "accuracy" + ) + ) + + if (is.null(linetype_str)) { + plt <- ggplot2::ggplot(plt_df) + + ggplot2::geom_point( + ggplot2::aes(x = .data[[x_str]], y = mean, + color = method, alpha = method, group = method), + size = point_size + ) + + ggplot2::geom_line( + ggplot2::aes(x = .data[[x_str]], y = mean, + color = 
method, alpha = method, group = method), + size = line_size + ) + + ggplot2::geom_errorbar( + ggplot2::aes(x = .data[[x_str]], ymin = mean - sd, ymax = mean + sd, + color = method, alpha = method, group = method), + width = errbar_width, show_guide = FALSE + ) + } else { + plt <- ggplot2::ggplot(plt_df) + + ggplot2::geom_point( + ggplot2::aes(x = .data[[x_str]], y = mean, + color = method, alpha = method, group = interaction(method, !!rlang::sym(linetype_str))), + size = point_size + ) + + ggplot2::geom_line( + ggplot2::aes(x = .data[[x_str]], y = mean, + color = method, alpha = method, group = interaction(method, !!rlang::sym(linetype_str)), + linetype = !!rlang::sym(linetype_str)), + size = line_size + ) + + ggplot2::geom_errorbar( + ggplot2::aes(x = .data[[x_str]], ymin = mean - sd, ymax = mean + sd, + color = method, alpha = method, group = interaction(method, !!rlang::sym(linetype_str))), + width = errbar_width, show_guide = FALSE + ) + } + if (!is.null(manual_color_palette)) { + if (is.null(alpha_values)) { + alpha_values <- c(1, rep(alpha, length(method_labels) - 1)) + } + plt <- plt + + ggplot2::scale_color_manual( + values = manual_color_palette, labels = method_labels + ) + + ggplot2::scale_alpha_manual( + values = alpha_values, labels = method_labels + ) + } + if (!is.null(custom_theme)) { + plt <- plt + custom_theme + } + + if (!is.null(facet_str)) { + plt <- plt + + ggplot2::facet_grid(reformulate(facet_str, "metric"), scales = "free") + } else if (length(metric) > 1) { + plt <- plt + + ggplot2::facet_wrap(~ metric, scales = "free") + } + + if (inside_legend) { + if (is.null(legend_position)) { + legend_position <- c(0.75, 0.3) + } + plt <- plt + + ggplot2::theme( + legend.title = ggplot2::element_blank(), + legend.position = legend_position, + legend.background = ggplot2::element_rect( + color = "slategray", size = 0.1 + ) + ) + } + + return(plt) +} + +# plot restricted metrics (mean value across repetitions with error bars) +plot_restricted_metrics <- function(results, metric = c("rocauc", "prauc"), + x_str, facet_str, + quantiles = c(0.1, 0.2, 0.3, 0.4), + point_size = 1, line_size = 1, errbar_width = 0, + alpha = 0.5, inside_legend = FALSE, + manual_color_palette = NULL, + show_methods = NULL, + method_labels = ggplot2::waiver(), + custom_theme = vthemes::theme_vmodern(size_preset = "medium")) { + if (is.null(show_methods)) { + show_methods <- sort(unique(results$method)) + } + results <- results %>% + dplyr::select(rep, method, fi_scores, + tidyselect::all_of(c(x_str, facet_str))) %>% + dplyr::mutate( + vars_ordered = purrr::map( + fi_scores, + function(fi_df) { + fi_df %>% + dplyr::filter(!is.na(cor_with_signal)) %>% + dplyr::arrange(-cor_with_signal) %>% + dplyr::pull(var) + } + ) + ) + + plt_df_ls <- list() + for (q in quantiles) { + plt_df_ls[[as.character(q)]] <- results %>% + dplyr::mutate( + restricted_metrics = purrr::map2_dfr( + fi_scores, vars_ordered, + function(fi_df, ignore_vars) { + ignore_vars <- ignore_vars[1:round(q * length(ignore_vars))] + auroc_r <- fi_df %>% + dplyr::filter(!(var %in% ignore_vars)) %>% + yardstick::roc_auc( + truth = factor(true_support, levels = c("1", "0")), importance, + event_level = "first" + ) %>% + dplyr::pull(.estimate) + auprc_r <- fi_df %>% + dplyr::filter(!(var %in% ignore_vars)) %>% + yardstick::pr_auc( + truth = factor(true_support, levels = c("1", "0")), importance, + event_level = "first" + ) %>% + dplyr::pull(.estimate) + return(data.frame(restricted_auroc = auroc_r, + restricted_auprc = auprc_r)) + } + ) + ) %>% + 
tidyr::unnest(restricted_metrics) %>% + tidyr::pivot_longer( + cols = c(restricted_auroc, restricted_auprc), names_to = "metric" + ) %>% + dplyr::group_by( + method, metric, dplyr::across(tidyselect::all_of(c(x_str, facet_str))) + ) %>% + dplyr::summarise(mean = mean(value), + sd = sd(value) / sqrt(dplyr::n()), + .groups = "keep") %>% + dplyr::ungroup() %>% + dplyr::filter(method %in% show_methods) %>% + dplyr::mutate( + method = factor(method, levels = show_methods), + metric = forcats::fct_recode( + factor(metric, levels = c("restricted_auroc", "restricted_auprc")), + `Restricted AUROC` = "restricted_auroc", + `Restricted AUPRC` = "restricted_auprc" + ) + ) + } + + plt_df <- purrr::map_dfr(plt_df_ls, ~.x, .id = ".threshold") %>% + dplyr::mutate(.threshold = as.numeric(.threshold)) + + plt <- ggplot2::ggplot(plt_df) + + ggplot2::geom_point( + ggplot2::aes(x = .data[[x_str]], y = mean, + color = method, alpha = method, group = method), + size = point_size + ) + + ggplot2::geom_line( + ggplot2::aes(x = .data[[x_str]], y = mean, + color = method, alpha = method, group = method), + size = line_size + ) + + ggplot2::geom_errorbar( + ggplot2::aes(x = .data[[x_str]], ymin = mean - sd, ymax = mean + sd, + color = method, alpha = method, group = method), + width = errbar_width, show_guide = FALSE + ) + if (!is.null(manual_color_palette)) { + plt <- plt + + ggplot2::scale_color_manual( + values = manual_color_palette, labels = method_labels + ) + + ggplot2::scale_alpha_manual( + values = c(1, rep(alpha, length(method_labels) - 1)), + labels = method_labels + ) + } + if (!is.null(custom_theme)) { + plt <- plt + custom_theme + } + + if (!is.null(facet_str)) { + formula <- sprintf("metric + .threshold ~ %s", paste0(facet_str, collapse = " + ")) + plt <- plt + + ggplot2::facet_grid(as.formula(formula)) + } else { + plt <- plt + + ggplot2::facet_wrap(.threshold ~ metric, scales = "free") + } + + if (inside_legend) { + plt <- plt + + ggplot2::theme( + legend.title = ggplot2::element_blank(), + legend.position = c(0.75, 0.3), + legend.background = ggplot2::element_rect( + color = "slategray", size = 0.1 + ) + ) + } + + return(plt) +} + +# plot true positive rate across # positives +plot_tpr <- function(results, facet_vars, point_size = 0.85, + manual_color_palette = NULL, + show_methods = NULL, + method_labels = ggplot2::waiver(), + custom_theme = vthemes::theme_vmodern(size_preset = "medium")) { + if (is.null(results)) { + return(NULL) + } + + if (is.null(show_methods)) { + show_methods <- sort(unique(results$method)) + } + + facet_names <- names(facet_vars) + names(facet_vars) <- NULL + + plt_df <- results %>% + dplyr::mutate( + fi_scores = mapply(name = fi, scores_df = fi_scores, + function(name, scores_df) { + scores_df <- scores_df %>% + dplyr::mutate( + ranking = rank(-importance, + ties.method = "random") + ) %>% + dplyr::arrange(ranking) %>% + dplyr::mutate( + .tp = cumsum(true_support) / sum(true_support) + ) + return(scores_df) + }, SIMPLIFY = FALSE) + ) %>% + tidyr::unnest(fi_scores) %>% + dplyr::select(tidyselect::all_of(facet_vars), rep, method, ranking, .tp) %>% + dplyr::group_by( + dplyr::across(tidyselect::all_of(facet_vars)), method, ranking + ) %>% + dplyr::summarise(.tp = mean(.tp), .groups = "keep") %>% + dplyr::mutate(method = factor(method, levels = show_methods)) %>% + dplyr::ungroup() + + if (!is.null(facet_names)) { + for (i in 1:length(facet_vars)) { + facet_var <- facet_vars[i] + facet_name <- facet_names[i] + if (facet_name != "") { + plt_df <- plt_df %>% + 
dplyr::mutate(dplyr::across( + tidyselect::all_of(facet_var), + ~factor(sprintf("%s = %s", facet_name, .x), + levels = sprintf("%s = %s", facet_name, sort(unique(.x)))) + )) + } + } + } + + if (length(facet_vars) == 1) { + plt <- ggplot2::ggplot(plt_df) + + ggplot2::aes(x = ranking, y = .tp, color = method) + + ggplot2::geom_line(size = point_size) + + ggplot2::facet_wrap(reformulate(facet_vars)) + } else { + plt <- ggplot2::ggplot(plt_df) + + ggplot2::aes(x = ranking, y = .tp, color = method) + + ggplot2::geom_line(size = point_size) + + ggplot2::facet_grid(reformulate(facet_vars[1], facet_vars[2])) + } + plt <- plt + + ggplot2::labs(x = "Top n", y = "True Positive Rate", + fill = "Method", color = "Method") + if (!is.null(manual_color_palette)) { + plt <- plt + + ggplot2::scale_color_manual( + values = manual_color_palette, labels = method_labels + ) + } + if (!is.null(custom_theme)) { + plt <- plt + custom_theme + } + + return(plt) +} + +# plot stability results +plot_perturbation_stability <- function(results, + facet_rows = "heritability_name", + facet_cols = "rho_name", + param_name = NULL, + facet_row_names = "Avg Rank (PVE = %s)", + group_fun = NULL, + descending_methods = NULL, + manual_color_palette = NULL, + show_methods = NULL, + method_labels = ggplot2::waiver(), + plot_types = c("boxplot", "errbar"), + save_dir = ".", + save_filename = NULL, + fig_height = 11, + fig_width = 11, + ...) { + plot_types <- match.arg(plot_types, several.ok = TRUE) + + my_theme <- vthemes::theme_vmodern( + size_preset = "medium", bg_color = "white", grid_color = "white", + axis.title = ggplot2::element_text(size = 12, face = "plain"), + legend.title = ggplot2::element_blank(), + legend.text = ggplot2::element_text(size = 9), + legend.text.align = 0, + plot.title = ggplot2::element_blank() + ) + + if (is.null(group_fun)) { + group_fun <- function(var, sig_ids, cnsig_ids) { + dplyr::case_when( + var %in% sig_ids ~ "Sig", + var %in% cnsig_ids ~ "C-NSig", + TRUE ~ "NSig" + ) %>% + factor(levels = c("Sig", "C-NSig", "NSig")) + } + } + + if (!is.null(show_methods)) { + results <- results %>% + dplyr::filter(fi %in% show_methods) + if (!identical(method_labels, ggplot2::waiver())) { + method_names <- show_methods + names(method_names) <- method_labels + results$fi <- do.call(forcats::fct_recode, + args = c(list(results$fi), as.list(method_names))) + results$fi <- factor(results$fi, levels = method_labels) + method_labels <- ggplot2::waiver() + } + } + + if (!is.null(descending_methods)) { + results <- results %>% + dplyr::mutate( + importance = ifelse(fi %in% descending_methods, -importance, importance) + ) + } + + rankings <- results %>% + dplyr::group_by( + rep, fi, + dplyr::across(tidyselect::all_of(c(facet_rows, facet_cols, param_name))), + ) %>% + dplyr::mutate( + rank = rank(-importance), + group = group_fun(var, ...) 
+ ) %>% + dplyr::ungroup() + + agg_rankings <- rankings %>% + dplyr::group_by( + rep, fi, group, + dplyr::across(tidyselect::all_of(c(facet_rows, facet_cols, param_name))) + ) %>% + dplyr::summarise( + avgrank = mean(rank), + .groups = "keep" + ) %>% + dplyr::ungroup() + + ymin <- min(agg_rankings$avgrank) + ymax <- max(agg_rankings$avgrank) + + for (type in plot_types) { + plt_ls <- list() + for (val in unique(agg_rankings[[facet_rows]])) { + if (identical(type, "boxplot")) { + plt <- agg_rankings %>% + dplyr::filter(.data[[facet_rows]] == val) %>% + ggplot2::ggplot() + + ggplot2::aes(x = group, y = avgrank, color = fi) + + ggplot2::geom_boxplot() + } else if (identical(type, "errbar")) { + plt <- agg_rankings %>% + dplyr::filter(.data[[facet_rows]] == val) %>% + dplyr::group_by( + fi, group, dplyr::across(tidyselect::all_of(facet_cols)) + ) %>% + dplyr::summarise( + .mean = mean(avgrank), + .sd = sd(avgrank), + .groups = "keep" + ) %>% + dplyr::ungroup() %>% + ggplot2::ggplot() + + # ggplot2::geom_point( + # ggplot2::aes(x = group, y = .mean, color = fi, group = fi), + # position = ggplot2::position_dodge2(width = 0.8, padding = 0.8) + # ) + + ggplot2::geom_errorbar( + ggplot2::aes(x = group, ymin = .mean - .sd, ymax = .mean + .sd, + color = fi, group = fi), + position = ggplot2::position_dodge2(width = 0, padding = 0.5), + width = 0.5, show_guide = FALSE + ) + } + plt <- plt + + ggplot2::coord_cartesian(ylim = c(ymin, ymax)) + + ggplot2::facet_grid(~ .data[[facet_cols]], labeller = ggplot2::label_parsed) + + my_theme + + ggplot2::theme( + panel.grid.major = ggplot2::element_line(colour = "#d9d9d9"), + panel.grid.major.x = ggplot2::element_blank(), + panel.grid.minor.x = ggplot2::element_blank(), + axis.line.y = ggplot2::element_blank(), + axis.ticks.y = ggplot2::element_blank(), + legend.position = "right" + ) + + ggplot2::labs(x = "Feature Groups", y = sprintf(facet_row_names, val)) + if (!is.null(manual_color_palette)) { + plt <- plt + + ggplot2::scale_color_manual(values = manual_color_palette, + labels = method_labels) + } + if (length(plt_ls) != 0) { + plt <- plt + + ggplot2::theme(strip.text = ggplot2::element_blank()) + } + plt_ls[[as.character(val)]] <- plt + } + agg_plt <- patchwork::wrap_plots(plt_ls) + + patchwork::plot_layout(ncol = 1, guides = "collect") + if (!is.null(save_filename)) { + ggplot2::ggsave( + filename = file.path(save_dir, + sprintf("%s_%s_aggregated.pdf", save_filename, type)), + plot = agg_plt, units = "in", width = fig_width, height = fig_height + ) + } + } + + unagg_plt <- NULL + if (!is.null(param_name)) { + plt_ls <- list() + for (val in unique(agg_rankings[[facet_rows]])) { + plt <- agg_rankings %>% + dplyr::filter(.data[[facet_rows]] == val) %>% + ggplot2::ggplot() + + ggplot2::aes(x = .data[[param_name]], y = avgrank, color = fi) + + ggplot2::geom_boxplot() + + ggplot2::coord_cartesian(ylim = c(ymin, ymax)) + + ggplot2::facet_grid( + reformulate(c(facet_cols, "group"), "fi"), + labeller = ggplot2::label_parsed + ) + + my_theme + + ggplot2::theme( + panel.grid.major = ggplot2::element_line(colour = "#d9d9d9"), + panel.grid.major.x = ggplot2::element_blank(), + panel.grid.minor.x = ggplot2::element_blank(), + axis.line.y = ggplot2::element_blank(), + axis.ticks.y = ggplot2::element_blank(), + legend.position = "none" + ) + + ggplot2::labs(x = param_name, y = sprintf(facet_row_names, val)) + if (!is.null(manual_color_palette)) { + plt <- plt + + ggplot2::scale_color_manual(values = manual_color_palette, + labels = method_labels) + } + 
plt_ls[[as.character(val)]] <- plt + } + unagg_plt <- patchwork::wrap_plots(plt_ls) + + patchwork::plot_layout(guides = "collect") + + if (!is.null(save_filename)) { + ggplot2::ggsave( + filename = file.path(save_dir, + sprintf("%s_unaggregated.pdf", save_filename)), + plot = unagg_plt, units = "in", width = fig_width, height = fig_height + ) + } + } + return(list(agg = agg_plt, unagg = unagg_plt)) +} + + +plot_top_stability <- function(results, + group_id = NULL, + top_r = 10, + show_max_features = 5, + varnames = NULL, + base_method = "MDI+_ridge", + return_df = FALSE, + descending_methods = NULL, + manual_color_palette = NULL, + show_methods = NULL, + method_labels = ggplot2::waiver(), + ...) { + + plt_ls <- list() + if (!is.null(show_methods)) { + results <- results %>% + dplyr::filter(fi %in% show_methods) + if (!identical(method_labels, ggplot2::waiver())) { + method_names <- show_methods + names(method_names) <- method_labels + results$fi <- do.call(forcats::fct_recode, + args = c(list(results$fi), as.list(method_names))) + results$fi <- factor(results$fi, levels = method_labels) + method_labels <- ggplot2::waiver() + } + } + + if (!is.null(descending_methods)) { + results <- results %>% + dplyr::mutate( + importance = ifelse(fi %in% descending_methods, -importance, importance) + ) + } + + if (is.null(varnames)) { + varnames <- 0:(length(unique(results$var)) - 1) + } + varnames_df <- data.frame(var = 0:(length(varnames) - 1), Feature = varnames) + + rankings <- results %>% + dplyr::group_by( + rep, fi, dplyr::across(tidyselect::all_of(group_id)) + ) %>% + dplyr::mutate( + rank = rank(-importance), + in_top_r = importance >= sort(importance, decreasing = TRUE)[top_r] + ) %>% + dplyr::ungroup() + + stability_df <- rankings %>% + dplyr::group_by( + fi, var, dplyr::across(tidyselect::all_of(group_id)) + ) %>% + dplyr::summarise( + stability_score = mean(in_top_r), + .groups = "keep" + ) %>% + dplyr::ungroup() + + n_nonzero_stability_df <- stability_df %>% + dplyr::group_by(fi, dplyr::across(tidyselect::all_of(group_id))) %>% + dplyr::summarise( + n_features = sum(stability_score > 0), + .groups = "keep" + ) %>% + dplyr::ungroup() + + if (!is.null(group_id)) { + order_groups <- n_nonzero_stability_df %>% + dplyr::group_by(dplyr::across(tidyselect::all_of(group_id))) %>% + dplyr::summarise( + order = min(n_features) + ) %>% + dplyr::arrange(order) %>% + dplyr::pull(tidyselect::all_of(group_id)) %>% + unique() + + n_nonzero_stability_df <- n_nonzero_stability_df %>% + dplyr::mutate( + dplyr::across( + tidyselect::all_of(group_id), ~factor(.x, levels = order_groups) + ) + ) + + ytext_label_colors <- n_nonzero_stability_df %>% + dplyr::group_by(dplyr::across(tidyselect::all_of(group_id))) %>% + dplyr::summarise( + is_best_method = n_features[fi == base_method] == min(n_features) + ) %>% + dplyr::mutate( + color = ifelse(is_best_method, "black", "#4a86e8") + ) %>% + dplyr::arrange(tidyselect::all_of(group_id)) %>% + dplyr::pull(color) + + plt <- vdocs::plot_horizontal_dotplot( + n_nonzero_stability_df, + x_str = "n_features", y_str = group_id, color_str = "fi", + theme_options = list(size_preset = "xlarge") + ) + + ggplot2::labs( + x = sprintf("Number of Distinct Features in Top 10 Across %s RF Fits", + length(unique(results$rep))), + color = "Method" + ) + + ggplot2::scale_y_discrete(limits = rev) + + ggplot2::theme( + axis.text.y = ggplot2::element_text(color = rev(ytext_label_colors)) + ) + + if (!is.null(manual_color_palette)) { + plt <- plt + + ggplot2::scale_color_manual(values = 
manual_color_palette, + labels = method_labels, + guide = ggplot2::guide_legend(reverse = TRUE)) + } + + plt_ls[["Summary"]] <- plt + + } else { + plt <- stability_df %>% + dplyr::left_join(y = varnames_df, by = "var") %>% + dplyr::filter(stability_score > 0) %>% + dplyr::left_join(y = n_nonzero_stability_df, by = "fi") %>% + dplyr::mutate(fi = sprintf("%s (# features = %s)", fi, n_features)) %>% + ggplot2::ggplot() + + ggplot2::aes( + x = reorder(Feature, -stability_score), y = stability_score + ) + + ggplot2::facet_wrap(~ fi, scales = "free_x", ncol = 1) + + ggplot2::geom_bar(stat = "identity") + + ggplot2::labs( + x = "Gene", y = sprintf("Proportion of RF Fits in Top %s", top_r) + ) + + vthemes::theme_vmodern(x_text_angle = TRUE, size_preset = "medium") + + if (!is.null(manual_color_palette)) { + plt <- plt + + ggplot2::scale_color_manual(values = manual_color_palette, + labels = method_labels, + guide = ggplot2::guide_legend(reverse = TRUE)) + } + plt_ls[["Non-zero Stability Scores Per Method"]] <- plt + } + + if (is.null(group_id)) { + rankings <- rankings %>% + dplyr::mutate( + .group = "Stability of Top Features" + ) + group_id <- ".group" + } + + for (group in unique(rankings[[group_id]])) { + plt <- rankings %>% + dplyr::filter(.data[[group_id]] == group) %>% + dplyr::group_by(fi, var) %>% + dplyr::summarise( + mean_rank = mean(rank), + median_rank = median(rank), + `SD(Feature Rank)` = sd(rank), + .groups = "keep" + ) %>% + dplyr::group_by(fi) %>% + dplyr::mutate( + agg_feature_rank = rank(mean_rank, ties.method = "random") + ) %>% + dplyr::filter(agg_feature_rank <= show_max_features) %>% + dplyr::mutate(`Average Feature Rank` = as.factor(agg_feature_rank)) %>% + dplyr::left_join(y = varnames_df, by = "var") %>% + dplyr::rename("Method" = "fi") %>% + dplyr::ungroup() %>% + ggplot2::ggplot() + + ggplot2::aes( + x = `Average Feature Rank`, y = `SD(Feature Rank)`, fill = Method, label = Feature + ) + + ggplot2::geom_bar( + stat = "identity", position = "dodge" + ) + + ggplot2::labs(title = group) + + vthemes::scale_fill_vmodern(discrete = TRUE) + + vthemes::theme_vmodern() + + if (!is.null(manual_color_palette)) { + plt <- plt + + ggplot2::scale_fill_manual(values = manual_color_palette, + labels = method_labels) + } + plt_ls[[group]] <- plt + } + + if (return_df) { + return(list(plot_ls = plt_ls, rankings = rankings, stability = stability_df)) + } else { + return(plt_ls) + } +} + + diff --git a/feature_importance/util.py b/feature_importance/util.py new file mode 100644 index 0000000..cb95e03 --- /dev/null +++ b/feature_importance/util.py @@ -0,0 +1,200 @@ +import copy +import os +import warnings +from functools import partial +from os.path import dirname +from os.path import join as oj +from typing import Any, Dict, Tuple + +import numpy as np +import pandas as pd +from sklearn import model_selection +from sklearn.metrics import confusion_matrix, precision_recall_curve, auc, average_precision_score, mean_absolute_error +from sklearn.preprocessing import label_binarize +from sklearn.utils._encode import _unique +from sklearn import metrics +from imodels.importance.ppms import huber_loss + +DATASET_PATH = oj(dirname(os.path.realpath(__file__)), 'data') + + +class ModelConfig: + def __init__(self, + name: str, cls, + vary_param: str = None, vary_param_val: Any = None, + other_params: Dict[str, Any] = {}, + model_type: str = None): + """ + name: str + Name of the model. + vary_param: str + Name of the parameter to be varied + model_type: str + Type of model. 
ID is used to pair with FIModel. + """ + + self.name = name + self.cls = cls + self.model_type = model_type + self.vary_param = vary_param + self.vary_param_val = vary_param_val + self.kwargs = {} + if self.vary_param is not None: + self.kwargs[self.vary_param] = self.vary_param_val + self.kwargs = {**self.kwargs, **other_params} + + def __repr__(self): + return self.name + + +class FIModelConfig: + def __init__(self, + name: str, cls, ascending = True, + splitting_strategy: str = None, + vary_param: str = None, vary_param_val: Any = None, + other_params: Dict[str, Any] = {}, + model_type: str = None): + """ + ascending: boolean + Whether or not feature importances should be ranked in ascending or + descending order. Default is True, indicating that higher feature + importance score is more important. + splitting_strategy: str + See util.apply_splitting_strategy(). Common inputs are "train-test" and None + vary_param: str + Name of the parameter to be varied + model_type: str + Type of model. ID is used to pair FIModel with Model. + """ + + assert splitting_strategy in { + 'train-test', 'train-tune-test', 'train-test-lowdata', 'train-tune-test-lowdata', + 'train-test-prediction', None + } + + self.name = name + self.cls = cls + self.ascending = ascending + self.model_type = model_type + self.splitting_strategy = splitting_strategy + self.vary_param = vary_param + self.vary_param_val = vary_param_val + self.kwargs = {} + if self.vary_param is not None: + self.kwargs[self.vary_param] = self.vary_param_val + self.kwargs = {**self.kwargs, **other_params} + + def __repr__(self): + return self.name + + +def tp(y_true, y_pred): + conf_mat = confusion_matrix(y_true, y_pred) + return conf_mat[1][1] + + +def fp(y_true, y_pred): + conf_mat = confusion_matrix(y_true, y_pred) + return conf_mat[0][1] + + +def neg(y_true, y_pred): + return sum(y_pred == 0) + + +def pos(y_true, y_pred): + return sum(y_pred == 1) + + +def specificity_score(y_true, y_pred): + conf_mat = confusion_matrix(y_true, y_pred) + return conf_mat[0][0] / (conf_mat[0][0] + conf_mat[0][1]) + + +def auprc_score(y_true, y_score, multi_class="raise"): + assert multi_class in ["raise", "ovr"] + n_classes = len(np.unique(y_true)) + if n_classes <= 2: + precision, recall, _ = precision_recall_curve(y_true, y_score) + return auc(recall, precision) + else: + # ovr is same as multi-label + if multi_class == "raise": + raise ValueError("Must set multi_class='ovr' to evaluate multi-class predictions.") + classes = _unique(y_true) + y_true_multilabel = label_binarize(y_true, classes=classes) + return average_precision_score(y_true_multilabel, y_score) + + +def auroc_score(y_true, y_score, multi_class="raise", **kwargs): + assert multi_class in ["raise", "ovr"] + n_classes = len(np.unique(y_true)) + if n_classes <= 2: + return metrics.roc_auc_score(y_true, y_score, multi_class=multi_class, **kwargs) + else: + # ovr is same as multi-label + if multi_class == "raise": + raise ValueError("Must set multi_class='ovr' to evaluate multi-class predictions.") + classes = _unique(y_true) + y_true_multilabel = label_binarize(y_true, classes=classes) + return metrics.roc_auc_score(y_true_multilabel, y_score, **kwargs) + + +def neg_mean_absolute_error(y_true, y_pred, **kwargs): + return -mean_absolute_error(y_true, y_pred, **kwargs) + + +def neg_huber_loss(y_true, y_pred, **kwargs): + return -huber_loss(y_true, y_pred, **kwargs) + + +def restricted_roc_auc_score(y_true, y_score, ignored_indices=[]): + """ + Compute AUROC score for only a subset of the samples + 
+    :param y_true: true binary labels
+    :param y_score: predicted scores
+    :param ignored_indices: indices of samples to exclude from the computation
+    :return: AUROC computed on the remaining samples
+    """
+    n_examples = len(y_true)
+    mask = [i for i in range(n_examples) if i not in ignored_indices]
+    restricted_auc = metrics.roc_auc_score(np.array(y_true)[mask], np.array(y_score)[mask])
+    return restricted_auc
+
+
+def compute_nsg_feat_corr_w_sig_subspace(signal_features, nonsignal_features, normalize=True):
+    """
+    Measure how correlated each (optionally normalized) non-signal feature is with the
+    signal subspace, via the norm of its projection onto an orthonormal basis of the
+    signal features.
+    """
+    if normalize:
+        normalized_nsg_features = nonsignal_features / np.linalg.norm(nonsignal_features, axis=0)
+    else:
+        normalized_nsg_features = nonsignal_features
+
+    q, r = np.linalg.qr(signal_features)
+    projections = np.linalg.norm(q.T @ normalized_nsg_features, axis=0)
+    # nsg_feat_ranked_by_projection = np.argsort(projections)
+
+    return projections
+
+
+def apply_splitting_strategy(X: np.ndarray,
+                             y: np.ndarray,
+                             splitting_strategy: str,
+                             split_seed: int) -> Tuple[Any, Any, Any, Any, Any, Any]:
+    if splitting_strategy in {'train-test-lowdata', 'train-tune-test-lowdata'}:
+        test_size = 0.90  # keep only 10% of the samples for training
+    elif splitting_strategy == "train-test":
+        test_size = 0.33
+    else:
+        test_size = 0.2
+
+    X_train, X_test, y_train, y_test = model_selection.train_test_split(
+        X, y, test_size=test_size, random_state=split_seed)
+    X_tune = None
+    y_tune = None
+
+    if splitting_strategy in {'train-tune-test', 'train-tune-test-lowdata'}:
+        X_train, X_tune, y_train, y_tune = model_selection.train_test_split(
+            X_train, y_train, test_size=0.2, random_state=split_seed)
+
+    return X_train, X_tune, X_test, y_train, y_tune, y_test