
Official repository of ICML 2023 paper: Dividing and Conquering a BlackBox to a Mixture of Interpretable Models: Route, Interpret, Repeat


batmanlab/ICML-2023-Route-interpret-repeat


Dividing and Conquering a BlackBox to a Mixture of Interpretable Models: Route, Interpret, Repeat

Official PyTorch implementation of the paper
Dividing and Conquering a BlackBox to a Mixture of Interpretable Models: Route, Interpret, Repeat
Shantanu Ghosh¹, Ke Yu², Forough Arabshahi³, Kayhan Batmanghelich¹
¹BU ECE, ²Pitt ISP, ³META AI
In ICML, 2023

Table of Contents

  1. Objective
  2. Environment setup
  3. Downloading data
  4. Data preprocessing
  5. Training pipeline
  6. Generated Local Explanations
  7. Suggestions
  8. Checkpoints
  9. How to Cite
  10. License and copyright
  11. Contact

Objective

In this paper, we blur the dichotomy between explaining a Blackbox post hoc and building models that are interpretable by design. Beginning with a Blackbox, we iteratively carve out a mixture of interpretable experts (MoIE) and a residual network. Each interpretable model specializes in a subset of samples and explains them using First Order Logic (FOL). We route the remaining samples through a flexible residual. We repeat the method on the residual network until all the interpretable models explain the desired proportion of data. Our method is summarized in the illustration below:


Refer below for examples of the explanations generated by MoIE:

Environment setup

conda env create --name python_3_7_rtx_6000 -f environment.yml
conda activate python_3_7_rtx_6000

Downloading data

After downloading the data from the links below, search the codebase for the --data-root variable and replace its value with the appropriate path for each dataset. Also search for /ocean/projects/asc170022p/shg121/PhD/ICLR-2022 and replace it with your own project path.
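Because the hard-coded project root appears in many files, a small script can do the replacement in one pass. The sketch below is not part of the repository: the function name `replace_project_root`, the file suffixes it scans, and the dry-run default are our assumptions. Review the reported files (and back up or commit the repository) before running it with `dry_run=False`.

```python
from pathlib import Path

# Hard-coded project root that the README asks you to replace.
OLD_ROOT = "/ocean/projects/asc170022p/shg121/PhD/ICLR-2022"


def replace_project_root(codebase, new_root, old_root=OLD_ROOT,
                         suffixes=(".py", ".sh", ".json", ".ipynb"),
                         dry_run=True):
    """Hypothetical helper: replace `old_root` with `new_root` in text files.

    Returns the files that contain `old_root`. With dry_run=True (the
    default) nothing is written, so the list can be inspected first.
    """
    changed = []
    for path in sorted(Path(codebase).rglob("*")):
        if not path.is_file() or path.suffix not in suffixes:
            continue
        try:
            text = path.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            continue  # skip files that are not valid UTF-8
        if old_root in text:
            changed.append(path)
            if not dry_run:
                path.write_text(text.replace(old_root, new_root), encoding="utf-8")
    return changed
```

For example, `replace_project_root("./src", "/path/to/your/project")` only lists the affected files; pass `dry_run=False` to rewrite them.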

(a) Downloading vision and skin data

Dataset | Description | URL
CUB-200 | Bird Classification dataset | CUB-200 Official
Derm7pt | Dermatology Concepts Dataset | Get access here
HAM10k | Skin lesion classification dataset | Kaggle Link
SIIM_ISIC | Skin Melanoma classification | SIIM-ISIC Kaggle
Awa2 | Animals with Attributes2 | Awa2 official

(b) Downloading MIMIC-CXR

For more details please follow the AGXNet Repository.

Data preprocessing

(a) Preprocessing CUB200

To obtain the CUB200 metadata and dataset splits, follow Logic Explained Networks. Once the JSON files are downloaded, search the codebase for the --json-root variable and set it to the appropriate path for each dataset.

To preprocess the concepts for CUB200, run:

python ./src/codebase/data_preprocessing/download_cub.py
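For context, Concept Bottleneck Models denoise CUB's per-image attribute annotations by majority voting within each class and then keep only the attributes that are present in at least 10 classes. We have not verified that download_cub.py performs exactly this step; the sketch below only illustrates that kind of filtering, and the function name `denoise_class_concepts` and the array layout are assumptions.

```python
import numpy as np


def denoise_class_concepts(attributes, labels, num_classes, min_classes=10):
    """Majority-vote per-image attributes into class-level concepts (sketch).

    attributes: (N, C) binary array of per-image attribute annotations.
    labels:     (N,) integer class ids in [0, num_classes).
    Returns per-image concept vectors restricted to the kept concepts,
    and the indices of the kept concepts.
    """
    attributes = np.asarray(attributes)
    labels = np.asarray(labels)
    class_concepts = np.zeros((num_classes, attributes.shape[1]), dtype=int)
    for c in range(num_classes):
        rows = attributes[labels == c]
        if len(rows):
            # A concept is "on" for a class if at least half of its images have it.
            class_concepts[c] = (rows.mean(axis=0) >= 0.5).astype(int)
    # Keep only concepts that are "on" for at least `min_classes` classes.
    keep = class_concepts.sum(axis=0) >= min_classes
    # Every image inherits the (filtered) concept vector of its class.
    per_image = class_concepts[labels][:, keep]
    return per_image, np.flatnonzero(keep)
```

Replacing each image's noisy annotation with its class-level vector trades per-image detail for consistency across a class.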

(b) Preprocessing MIMIC-CXR

To preprocess MIMIC-CXR for Effusion, run the following steps in order:

  1. To generate itemized RadGraph examples, run:
python ./src/codebase/data_preprocessing/mimic-cxr/miccai-main/preprocessing/radgraph_itemized.py
  2. To parse RadGraph relations, run:
python ./src/codebase/data_preprocessing/mimic-cxr/miccai-main/preprocessing/radgraph_parsed.py
  3. To create the adjacency matrix that represents the relations between anatomical landmarks and observations mentioned in radiology reports, run:
python ./src/codebase/data_preprocessing/mimic-cxr/miccai-main/preprocessing/adj_matrix.py

The output of Step 3 provides the concepts for training MoIE-CXR. Remove the disease label being classified from these concepts: because RadGraph extracts both anatomies and observations, the label itself appears among the Step 3 concepts. For example, when classifying Pneumonia (the disease label), Pneumonia will also appear as a concept. This is redundant, so in this case remove Pneumonia from the concepts manually.
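The manual removal can be sketched as dropping one named entry from the concept list and the matching slice of the concept array. This is a hypothetical helper: it assumes you hold the concept names as a list parallel to one axis of a NumPy array (by default the last axis); the actual layout of the Step 3 outputs may differ.

```python
import numpy as np


def drop_label_concept(concept_names, concept_matrix, disease, axis=-1):
    """Remove the concept named `disease` (case-insensitive) along `axis`.

    concept_names:  list of concept names, parallel to `axis` of the matrix.
    concept_matrix: array whose `axis` dimension indexes the concepts.
    """
    matrix = np.asarray(concept_matrix)
    matches = [i for i, name in enumerate(concept_names)
               if name.lower() == disease.lower()]
    if not matches:
        return list(concept_names), matrix  # nothing to remove
    kept_names = [n for i, n in enumerate(concept_names) if i not in matches]
    return kept_names, np.delete(matrix, matches, axis=axis)
```

For Effusion, call it with `disease="effusion"`; for Pneumonia, with `disease="pneumonia"`.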

If you prefer not to run the MIMIC-CXR preprocessing steps on the RadGraph files yourself, use the files below directly (they are the outputs of the three steps above). Place them in the folders indicated in the code for training the Blackbox, the concept predictor (t), and the MoIE experts:

Variable | Description | Paths
--radgraph-adj-mtx-pickle-file | RadGraph adjacency matrix (landmark-observation) | landmark_observation_adj_mtx_v2.pickle
--radgraph-sids-npy-file | RadGraph study IDs | landmark_observation_sids_v2.npy
--radgraph-adj-mtx-npy-file | RadGraph adjacency matrix (landmark-observation) | landmark_observation_adj_mtx_v2.npy
--nvidia-bounding-box-file | Bounding boxes annotated for pneumonia and pneumothorax | mimic-cxr-annotation.csv
--imagenome-radgraph-landmark-mapping-file | Landmark mapping between ImaGenome and RadGraph | landmark_mapping.json
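Before launching training, it can help to check that the downloaded files are in place and load correctly. A minimal sketch, assuming all four files sit in one folder (the real code may read each from a different location via the variables above); the helper name `load_radgraph_files` is hypothetical, and the CSV of bounding boxes is not covered here.

```python
import json
import pickle
from pathlib import Path

import numpy as np

# File names taken from the table above; a single shared folder is an assumption.
FILES = {
    "adj_pickle": "landmark_observation_adj_mtx_v2.pickle",
    "sids": "landmark_observation_sids_v2.npy",
    "adj_npy": "landmark_observation_adj_mtx_v2.npy",
    "landmark_mapping": "landmark_mapping.json",
}


def load_radgraph_files(folder):
    """Load the precomputed RadGraph outputs from `folder` (sketch)."""
    folder = Path(folder)
    missing = [name for name in FILES.values() if not (folder / name).exists()]
    if missing:
        raise FileNotFoundError(f"missing in {folder}: {missing}")
    with open(folder / FILES["adj_pickle"], "rb") as f:
        adj_pickle = pickle.load(f)  # only unpickle files from a trusted source
    sids = np.load(folder / FILES["sids"], allow_pickle=True)
    adj_npy = np.load(folder / FILES["adj_npy"], allow_pickle=True)
    with open(folder / FILES["landmark_mapping"], encoding="utf-8") as f:
        mapping = json.load(f)
    return {"adj_pickle": adj_pickle, "sids": sids,
            "adj_npy": adj_npy, "landmark_mapping": mapping}
```

Printing `out["sids"].shape` and `out["adj_npy"].shape` is a quick way to confirm the downloads are complete.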

Training pipeline

To train on MIMIC-CXR, please follow our MICCAI 2023 repository, which handles the high class imbalance.

The scripts for training MoIE on all datasets and architectures, with comments, are in the ./src/scripts folder. Run the commands in each script sequentially to train and test the Blackbox (BB), concept predictor (t), explainers (g), and residuals (r).

  • As a first step, find the project path /ocean/projects/asc170022p/shg121/PhD/ICLR-2022 throughout the codebase and replace it with your own path.

  • After training and testing MoIE, the last step in each script runs FOLs_vision_main.py, which generates the instance-specific FOLs. This file reads ./src/codebase/Completeness_and_interventions/paths_MoIE.json, where we keep the paths and filenames of the checkpoints of the Blackbox (bb), concept predictor (t), explainer (g), and residual (r). Replace those paths and filenames with the ones from your experiments. The variables in paths_MoIE.json are described below:

Variable | Description
cub_ViT-B_16 | Root variable for CUB200 dataset with Vision Transformer as the Blackbox (BB)
cub_ResNet101 | Root variable for CUB200 dataset with ResNet101 as the Blackbox (BB)
HAM10k_Inception_V3 | Root variable for HAM10k dataset with Inception_v3 as the Blackbox (BB)
SIIM-ISIC_Inception_V3 | Root variable for SIIM-ISIC dataset with Inception_v3 as the Blackbox (BB)
awa2_ViT-B_16 | Root variable for Awa2 dataset with Vision Transformer as the Blackbox (BB)
awa2_ResNet101 | Root variable for Awa2 dataset with ResNet101 as the Blackbox (BB)
  • Note that each root follows the dataset_BB_architecture format; do not modify this format. For each of the above roots in paths_MoIE.json, edit the values of MoIE_paths, t, and bb with the appropriate checkpoint paths and filenames for the experts (g), concept predictor (t), and Blackbox (bb), according to the dataset and architecture.

  • Similarly, edit the checkpoint paths and filenames in ./src/codebase/MIMIC_CXR/paths_mimic_cxr_icml.json for the effusion experiments on MIMIC-CXR.
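A quick structural check of paths_MoIE.json can catch a renamed root or a missing key before FOL generation fails. This sketch assumes, as the note above suggests, that each root maps to an object containing MoIE_paths, t, and bb keys; it does not inspect the values inside those keys. The helper names are hypothetical.

```python
import json

# Root keys from the table above (dataset_BB_architecture format).
EXPECTED_ROOTS = (
    "cub_ViT-B_16", "cub_ResNet101", "HAM10k_Inception_V3",
    "SIIM-ISIC_Inception_V3", "awa2_ViT-B_16", "awa2_ResNet101",
)
REQUIRED_KEYS = ("MoIE_paths", "t", "bb")


def check_paths_config(config, roots=EXPECTED_ROOTS):
    """Return a list of problems found in a parsed paths_MoIE.json (sketch)."""
    problems = []
    for root in roots:
        entry = config.get(root)
        if not isinstance(entry, dict):
            problems.append(f"missing root: {root}")
            continue
        for key in REQUIRED_KEYS:
            if key not in entry:
                problems.append(f"{root}: missing '{key}'")
    return problems


def check_paths_file(path, roots=EXPECTED_ROOTS):
    """Load a JSON file and check it; pass only the roots you actually use."""
    with open(path, encoding="utf-8") as f:
        return check_paths_config(json.load(f), roots)
```

For example, `check_paths_file("./src/codebase/Completeness_and_interventions/paths_MoIE.json", roots=("cub_ViT-B_16",))` checks only the CUB200 ViT entry.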

Refer to the following sections for details of each of the scripts.

(a) Running MoIE

Script name | Description | Comment
./src/scripts/cub_resnet.sh | Script for CUB200 dataset with Resnet101 as the Blackbox (BB) | Included train/test and FOL generation script for the Blackbox (BB), concept predictor (t), explainers (g) and residuals (r)
./src/scripts/cub_vit.sh | Script for CUB200 dataset with Vision Transformer as the Blackbox (BB) | Included train/test and FOL generation script for the Blackbox (BB), concept predictor (t), explainers (g) and residuals (r)
./src/scripts/awa2_resnet.sh | Script for Awa2 dataset with Resnet101 as the Blackbox (BB) | Included train/test and FOL generation script for the Blackbox (BB), concept predictor (t), explainers (g) and residuals (r)
./src/scripts/awa2_vit.sh | Script for Awa2 dataset with Vision Transformer as the Blackbox (BB) | Included train/test and FOL generation script for the Blackbox (BB), concept predictor (t), explainers (g) and residuals (r)
./src/scripts/ham_10k.sh | Script for HAM10k dataset with Inception_v3 as the Blackbox (BB) | Included train/test and FOL generation script for the Blackbox (BB), concept predictor (t), explainers (g) and residuals (r)
./src/scripts/SIIM-ISIC.sh | Script for SIIM-ISIC dataset with Inception_v3 as the Blackbox (BB) | Included train/test and FOL generation script for the Blackbox (BB), concept predictor (t), explainers (g) and residuals (r)
./src/scripts/mimic_effusion.sh | Script for MIMIC-CXR dataset with Densenet121 as the Blackbox (BB) | Included train/test and FOL generation script for the Blackbox (BB), concept predictor (t), explainers (g) and residuals (r)

For reference, check the following repositories for SOTA Blackboxes and concepts:

(b) Compute the performance metrics

To compute the cumulative performance metrics (accuracy/AUROC) of all the experts (Table 2 in the paper), use the following iPython notebooks:

Notebook | Description
./src/codebase/iPython/Cumulative_performance/CUB-Resnet.ipynb | Script for CUB200 dataset with Resnet101 as the Blackbox (BB)
./src/codebase/iPython/Cumulative_performance/CUB-VIT.ipynb | Script for CUB200 dataset with Vision Transformer as the Blackbox (BB)
./src/codebase/iPython/Cumulative_performance/AWA2-Resnet.ipynb | Script for Awa2 dataset with Resnet101 as the Blackbox (BB)
./src/codebase/iPython/Cumulative_performance/AWA2-VIT.ipynb | Script for Awa2 dataset with Vision Transformer as the Blackbox (BB)
./src/codebase/iPython/Cumulative_performance/HAM10k.ipynb | Script for HAM10k dataset with Inception_v3 as the Blackbox (BB)
./src/codebase/iPython/Cumulative_performance/ISIC.ipynb | Script for SIIM-ISIC dataset with Inception_v3 as the Blackbox (BB)
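"Cumulative" here means the metric of experts 1..k evaluated on the union of the samples they cover. A minimal sketch of cumulative accuracy, assuming each expert's covered test samples are disjoint and that you already have per-expert labels and predictions; the notebooks may organize this differently, and `cumulative_accuracy` is a hypothetical name.

```python
import numpy as np


def cumulative_accuracy(expert_outputs):
    """Accuracy of experts 1..k on the union of their covered samples, for each k.

    expert_outputs: list of (y_true, y_pred) pairs in iteration order, each
    over the (disjoint) test samples routed to that expert.
    """
    results, correct, total = [], 0, 0
    for y_true, y_pred in expert_outputs:
        y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
        if y_true.shape != y_pred.shape:
            raise ValueError("y_true and y_pred must have the same shape")
        correct += int((y_true == y_pred).sum())
        total += y_true.size
        results.append(correct / total if total else float("nan"))
    return results
```

For AUROC, the pooled score is not an average of per-expert AUROCs: concatenate the labels and scores of experts 1..k and pass them to `sklearn.metrics.roc_auc_score`.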

For effusion in MIMIC-CXR, the command to estimate the AUROC of all the experts is:

python ./src/codebase/performance_calculation_mimic_cxr_main.py --iterations 3 --icml "y" --disease "effusion" --model "MoIE"

This command is already included in the file ./src/scripts/mimic_effusion.sh.

(c) Validating the concept importance

In the paper, we validate the importance of the extracted concepts using three experiments:

  1. Zeroing out the important concepts
  2. Computing the completeness scores of the important concepts
    • Before running the completeness score script, run the following notebooks to create the dataset used to train the projection model from the completeness score paper:
Notebook | Description
./src/codebase/iPython/Completeness_dataset/CUB_Resnet.ipynb | Script for CUB200 dataset with Resnet101 as the Blackbox (BB)
./src/codebase/iPython/Completeness_dataset/CUB_VIT.ipynb | Script for CUB200 dataset with Vision Transformer as the Blackbox (BB)
./src/codebase/iPython/Completeness_dataset/Awa2_Resnet.ipynb | Script for Awa2 dataset with Resnet101 as the Blackbox (BB)
./src/codebase/iPython/Completeness_dataset/Awa2_VIT.ipynb | Script for Awa2 dataset with Vision Transformer as the Blackbox (BB)
./src/codebase/iPython/Completeness_dataset/HAM10k.ipynb | Script for HAM10k dataset with Inception_v3 as the Blackbox (BB)
  3. Performing test-time interventions
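For reference, the completeness score of Yeh et al. (2020), which we take to be the completeness score paper mentioned above, normalizes the accuracy obtained from the concepts (through the learned projection model) by the Blackbox accuracy, with both offset by the accuracy of random guessing: η = (a_c - a_r) / (a_bb - a_r). The sketch below computes only this final ratio; obtaining a_c requires training the projection model, which is not shown, and the function name is hypothetical.

```python
def completeness_score(concept_accuracy, blackbox_accuracy, random_accuracy):
    """Completeness score: (a_c - a_r) / (a_bb - a_r)  (sketch).

    concept_accuracy:  accuracy when predicting from the concepts via the projection model
    blackbox_accuracy: accuracy of the original Blackbox
    random_accuracy:   accuracy of random guessing (e.g. 1/K for K balanced classes)
    """
    if blackbox_accuracy <= random_accuracy:
        raise ValueError("blackbox_accuracy must exceed random_accuracy")
    return (concept_accuracy - random_accuracy) / (blackbox_accuracy - random_accuracy)
```

A score near 1 means the selected concepts recover almost all of the Blackbox's above-chance accuracy.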

Please refer to the table below for the scripts to replicate the above experiments (zeroing out the concepts, completeness scores and test time interventions):

Scripts | Description
./src/scripts/zero_out_concepts.sh | Script to zero out the important concepts
./src/scripts/completeness_scores.sh | Script to estimate the completeness scores of the important concepts
./src/scripts/tti.sh | Script to perform test time interventions for the important concepts
./src/codebase/tti_experts.sh | Script to perform test time interventions for the important concepts corresponding to only the harder samples covered by the last two experts
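Both zeroing out and test-time interventions modify the concept vector fed to an expert at its most important concepts: zeroing out sets them to 0, while an intervention replaces the predicted values with ground-truth ones. A minimal NumPy sketch, assuming concepts are stored as an (N, C) array and importance as per-sample scores of the same shape; the function names are hypothetical and the repository's scripts may select concepts differently (for example, from the FOL explanations).

```python
import numpy as np


def topk_mask(importance, k):
    """Boolean (N, C) mask marking the k most important concepts per sample."""
    importance = np.asarray(importance, dtype=float)
    top = np.argsort(-importance, axis=1)[:, :k]
    mask = np.zeros(importance.shape, dtype=bool)
    np.put_along_axis(mask, top, True, axis=1)
    return mask


def zero_out_concepts(concepts, mask):
    """Set the masked concept activations to 0 (concept ablation)."""
    return np.where(mask, 0.0, np.asarray(concepts, dtype=float))


def intervene_on_concepts(pred_concepts, true_concepts, mask):
    """Replace masked predicted concepts with ground truth (test-time intervention)."""
    return np.where(mask, np.asarray(true_concepts, dtype=float),
                    np.asarray(pred_concepts, dtype=float))
```

Passing the modified concepts back through the expert and comparing accuracy with the unmodified run gives the drop (for zeroing out) or gain (for interventions).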

Generated Local Explanations

We have included the instance-specific explanations per expert for each dataset in the folder ./explanations.

Suggestions

Most of the argparse variables are self-explanatory. However, in order to perform the experiments successfully, give the correct paths and files to the following variables in train_explainer_<dataset>.py and test_explainer_<dataset>.py.

  • For train_explainer_<dataset>.py (e.g., train_explainer_CUB.py, train_explainer_ham10k.py), follow these rules:

    1. --checkpoint-model: Don't include this variable for the 1st iteration. From the 2nd iteration onwards, when training the expert (g) (--expert-to-train "explainer"), include the checkpoint files of all the experts of previous iterations. For example, if the current iteration is 3, include the checkpoint files for expert 1 and expert 2, in that order. When training the residual (--expert-to-train "residual"), include the checkpoint files of all the experts up to and including the current iteration.
    2. --checkpoint-residual: Don't include this variable for the 1st iteration. From the 2nd iteration onwards, include the checkpoint files of all the residuals of previous iterations, both when training the expert (g) (--expert-to-train "explainer") and when training the residual (--expert-to-train "residual"). For example, if the current iteration is 3, include the checkpoint files for residual 1 and residual 2, in that order.
    3. --prev_explainer_chk_pt_folder: Don't include this variable for the 1st iteration. From the 2nd iteration onwards, include the folders containing the checkpoint files of all the experts of previous iterations. For example, if the current iteration is 3, include the checkpoint folders for expert 1 and expert 2, in that order. For all datasets other than MIMIC-CXR, give the absolute path. For MIMIC-CXR, give only the experiment folder where the checkpoint file will be stored.

    Refer to the following example command for the 3rd iteration for CUB200 dataset with VIT as the blackbox to train the expert:

    python ./src/codebase/train_explainer_CUB.py --expert-to-train "explainer" --checkpoint-model checkpt_expert1 checkpt_expert2 --checkpoint-residual checkpt_residual1 checkpt_residual2 --prev_explainer_chk_pt_folder checkpt_folder_expert1 checkpt_folder_expert2 --root-bb "lr_0.03_epochs_95" --checkpoint-bb "VIT_CUBS_8000_checkpoint.bin" --iter 3 --dataset "cub" --cov cov_iter1 cov_iter2 cov_iter3 --bs 16 --dataset-folder-concepts "lr_0.03_epochs_95_ViT-B_16_layer4_VIT_sgd_BCE" --lr learning_rate_iter1 learning_rate_iter2 learning_rate_iter3 --input-size-pi 2048 --temperature-lens 0.7 --lambda-lens 0.0001 --alpha-KD 0.9 --temperature-KD 10 --hidden-nodes 10 --layer "VIT" --arch "VIT-B_16"

    Similarly, refer to the following example command for the 3rd iteration for CUB200 dataset with VIT as the blackbox to train the residual:

    python ./src/codebase/train_explainer_CUB.py --expert-to-train "residual" --checkpoint-model checkpt_expert1 checkpt_expert2 checkpt_expert3 --checkpoint-residual checkpt_residual1 checkpt_residual2 --prev_explainer_chk_pt_folder checkpt_folder_expert1 checkpt_folder_expert2 --root-bb "lr_0.03_epochs_95" --checkpoint-bb "VIT_CUBS_8000_checkpoint.bin" --iter 3 --dataset "cub" --cov cov_iter1 cov_iter2 cov_iter3 --bs 16 --dataset-folder-concepts "lr_0.03_epochs_95_ViT-B_16_layer4_VIT_sgd_BCE" --lr learning_rate_iter1 learning_rate_iter2 learning_rate_iter3 --input-size-pi 2048 --temperature-lens 0.7 --lambda-lens 0.0001 --alpha-KD 0.9 --temperature-KD 10 --hidden-nodes 10 --layer "VIT" --arch "VIT-B_16"
  • For test_explainer_<dataset>.py (e.g., test_explainer_CUB.py, test_explainer_ham10k.py), follow these rules:

    1. --checkpoint-model: Don't include this variable for the 1st iteration. From the 2nd iteration onwards, include the checkpoint files of all the experts up to and including the current iteration, both when testing the expert (g) (--expert-to-train "explainer") and when testing the residual (--expert-to-train "residual").
    2. --checkpoint-residual: Don't include this variable for the 1st iteration. From the 2nd iteration onwards, when testing the expert (g) (--expert-to-train "explainer"), include the checkpoint files of all the residuals of previous iterations. For example, if the current iteration is 3, include the checkpoint files for residual 1 and residual 2, in that order. When testing the residual (--expert-to-train "residual"), include the checkpoint files of all the residuals up to and including the current iteration.
    3. --prev_explainer_chk_pt_folder: Don't include this variable for the 1st iteration. From the 2nd iteration onwards, include the folders containing the checkpoint files of all the experts of previous iterations. For example, if the current iteration is 3, include the checkpoint folders for expert 1 and expert 2, in that order. For all datasets other than MIMIC-CXR, give the absolute path. For MIMIC-CXR, give only the experiment folder where the checkpoint file will be stored.

    Refer to the following example command for the 3rd iteration for CUB200 dataset with VIT as the blackbox to test the expert:

    python ./src/codebase/test_explainer_CUB.py --expert-to-train "explainer" --checkpoint-model checkpt_expert1 checkpt_expert2 checkpt_expert3 --checkpoint-residual checkpt_residual1 checkpt_residual2 --prev_explainer_chk_pt_folder checkpt_folder_expert1 checkpt_folder_expert2 --root-bb "lr_0.03_epochs_95" --checkpoint-bb "VIT_CUBS_8000_checkpoint.bin" --iter 3 --dataset "cub" --cov cov_iter1 cov_iter2 cov_iter3 --bs 16 --dataset-folder-concepts "lr_0.03_epochs_95_ViT-B_16_layer4_VIT_sgd_BCE" --lr learning_rate_iter1 learning_rate_iter2 learning_rate_iter3 --input-size-pi 2048 --temperature-lens 0.7 --lambda-lens 0.0001 --alpha-KD 0.9 --temperature-KD 10 --hidden-nodes 10 --layer "VIT" --arch "VIT-B_16"

    Similarly, refer to the following example command for the 3rd iteration for CUB200 dataset with VIT as the blackbox to test the residual:

    python ./src/codebase/test_explainer_CUB.py --expert-to-train "residual" --checkpoint-model checkpt_expert1 checkpt_expert2 checkpt_expert3 --checkpoint-residual checkpt_residual1 checkpt_residual2 checkpt_residual3 --prev_explainer_chk_pt_folder checkpt_folder_expert1 checkpt_folder_expert2 --root-bb "lr_0.03_epochs_95" --checkpoint-bb "VIT_CUBS_8000_checkpoint.bin" --iter 3 --dataset "cub" --cov cov_iter1 cov_iter2 cov_iter3 --bs 16 --dataset-folder-concepts "lr_0.03_epochs_95_ViT-B_16_layer4_VIT_sgd_BCE" --lr learning_rate_iter1 learning_rate_iter2 learning_rate_iter3 --input-size-pi 2048 --temperature-lens 0.7 --lambda-lens 0.0001 --alpha-KD 0.9 --temperature-KD 10 --hidden-nodes 10 --layer "VIT" --arch "VIT-B_16"
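The train/test rules above follow a fixed pattern, so they can be encoded in a small helper that assembles the three checkpoint arguments for a given iteration. This is a hypothetical sketch (not part of the repository) that reproduces the four example commands; the placeholder names are the same as in those commands.

```python
def checkpoint_args(iteration, stage, expert_to_train,
                    expert_ckpts, residual_ckpts, expert_folders):
    """Build the checkpoint-related CLI arguments for one run (sketch).

    stage: "train" or "test"; expert_to_train: "explainer" or "residual".
    """
    if stage not in ("train", "test"):
        raise ValueError("stage must be 'train' or 'test'")
    if expert_to_train not in ("explainer", "residual"):
        raise ValueError("expert_to_train must be 'explainer' or 'residual'")
    if iteration == 1:
        return []  # none of these variables is passed in the 1st iteration
    prev = iteration - 1
    # Experts: only previous ones when training an explainer, else up to the current one.
    n_experts = prev if (stage == "train" and expert_to_train == "explainer") else iteration
    # Residuals: previous ones, except when testing a residual (up to the current one).
    n_residuals = iteration if (stage == "test" and expert_to_train == "residual") else prev
    if (len(expert_ckpts) < n_experts or len(residual_ckpts) < n_residuals
            or len(expert_folders) < prev):
        raise ValueError("not enough checkpoints/folders for this iteration")
    return (["--checkpoint-model", *expert_ckpts[:n_experts]]
            + ["--checkpoint-residual", *residual_ckpts[:n_residuals]]
            + ["--prev_explainer_chk_pt_folder", *expert_folders[:prev]])
```

The returned list can be appended to the remaining arguments and joined with `shlex.join` (Python 3.8+) to print a ready-to-run command.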

Also make sure the following variables are correct:

  • --cov: Coverages of each iteration separated by a space as in the above commands.
  • --lr: Learning rates of each expert separated by a space as in the above commands.
  • --data-root: Path to the dataset images, labels, and concepts (if they exist)
  • --logs: Path for the TensorBoard logs
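In the example commands, --cov and --lr take one space-separated value per iteration. The sketch below shows how such flags parse with Python's argparse (`nargs="+"`) and adds a consistency check that the number of values equals --iter. It is a hypothetical subset of the real CLI: we assume, but have not verified, that the scripts declare these flags this way, and the (0, 1] range check on coverages is our assumption.

```python
import argparse


def build_parser():
    """Hypothetical subset of the explainer CLI, for illustration only."""
    parser = argparse.ArgumentParser(description="MoIE explainer arguments (sketch)")
    parser.add_argument("--iter", type=int, required=True)
    parser.add_argument("--cov", type=float, nargs="+", required=True,
                        help="coverage of each iteration, space separated")
    parser.add_argument("--lr", type=float, nargs="+", required=True,
                        help="learning rate of each expert, space separated")
    parser.add_argument("--data-root", default="./data")
    parser.add_argument("--logs", default="./logs")
    return parser


def parse_and_check(argv=None):
    """Parse arguments and check one --cov/--lr value per iteration."""
    args = build_parser().parse_args(argv)
    for name in ("cov", "lr"):
        values = getattr(args, name)
        if len(values) != args.iter:
            raise ValueError(f"--{name} has {len(values)} values but --iter is {args.iter}")
    if not all(0.0 < c <= 1.0 for c in args.cov):
        raise ValueError("--cov values must lie in (0, 1]")
    return args
```

Note that argparse exposes --data-root as `args.data_root`.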

Checkpoints

For the checkpoints of the pretrained blackboxes and concept banks, refer below:

Blackbox | Concept predictor (t) / Concept banks
CUB200-VIT | CUB200-VIT
HAM10k | HAM10k
Effusion-MIMIC-CXR | Effusion-MIMIC-CXR
Awa2-VIT | Awa2-VIT

Note that for HAM10k we provide the concept bank extracted after training t, so there is no need to train t for HAM10k and SIIM-ISIC if you use this concept bank. For the other datasets, the links above contain the checkpoints of t; use these checkpoints to extract the concepts.

How to Cite

  • Main ICML 2023 paper
@InProceedings{pmlr-v202-ghosh23c,
  title = 	 {Dividing and Conquering a {B}lack{B}ox to a Mixture of Interpretable Models: Route, Interpret, Repeat},
  author =       {Ghosh, Shantanu and Yu, Ke and Arabshahi, Forough and Batmanghelich, Kayhan},
  booktitle = 	 {Proceedings of the 40th International Conference on Machine Learning},
  pages = 	 {11360--11397},
  year = 	 {2023},
  editor = 	 {Krause, Andreas and Brunskill, Emma and Cho, Kyunghyun and Engelhardt, Barbara and Sabato, Sivan and Scarlett, Jonathan},
  volume = 	 {202},
  series = 	 {Proceedings of Machine Learning Research},
  month = 	 {23--29 Jul},
  publisher =    {PMLR},
  pdf = 	 {https://proceedings.mlr.press/v202/ghosh23c/ghosh23c.pdf},
  url = 	 {https://proceedings.mlr.press/v202/ghosh23c.html},
  abstract = 	 {ML model design either starts with an interpretable model or a Blackbox and explains it post hoc. Blackbox models are flexible but difficult to explain, while interpretable models are inherently explainable. Yet, interpretable models require extensive ML knowledge and tend to be less flexible, potentially underperforming than their Blackbox equivalents. This paper aims to blur the distinction between a post hoc explanation of a Blackbox and constructing interpretable models. Beginning with a Blackbox, we iteratively <em>carve out</em> a mixture of interpretable models and a <em>residual network</em>. The interpretable models identify a subset of samples and explain them using First Order Logic (FOL), providing basic reasoning on concepts from the Blackbox. We route the remaining samples through a flexible residual. We repeat the method on the residual network until all the interpretable models explain the desired proportion of data. Our extensive experiments show that our <em>route, interpret, and repeat</em> approach (1) identifies a richer diverse set of instance-specific concepts with high concept completeness via interpretable models by specializing in various subsets of data without compromising in performance, (2) identifies the relatively “harder” samples to explain via residuals, (3) outperforms the interpretable by-design models by significant margins during test-time interventions, (4) can be used to fix the shortcut learned by the original Blackbox.}
}
  • Shortcut paper published in Workshop on Spurious Correlations, Invariance and Stability, ICML 2023
@inproceedings{ghosh2023tackling,
  title={Tackling Shortcut Learning in Deep Neural Networks: An Iterative Approach with Interpretable Models},
  author={Ghosh, Shantanu and Yu, Ke and Arabshahi, Forough and Batmanghelich, Kayhan},
  booktitle={ICML 2023: Workshop on Spurious Correlations, Invariance and Stability},
  year={2023}
}
  • Transfer learning paper published in Workshop on Interpretable Machine Learning in Healthcare, ICML 2023
@inproceedings{ghosh2023bridging,
  title={Bridging the Gap: From Post Hoc Explanations to Inherently Interpretable Models for Medical Imaging},
  author={Ghosh, Shantanu and Yu, Ke and Arabshahi, Forough and Batmanghelich, Kayhan},
  booktitle={ICML 2023: Workshop on Interpretable Machine Learning in Healthcare},
  year={2023}
}

License and copyright

Licensed under the MIT License

Copyright © Batman Lab, 2023

Contact

For any queries, contact: [email protected]