- 2024/04/06: Our paper has been accepted at MIDL 2024!🎉🎉
Machine unlearning is a promising paradigm for removing unwanted data samples from a trained model, with the aim of ensuring compliance with privacy regulations and limiting harmful biases. Although unlearning has been demonstrated in, e.g., classification and recommendation systems, its potential in medical image-to-image translation, specifically in image reconstruction, has not been thoroughly investigated. This paper shows that machine unlearning is possible in MRI tasks and has the potential to benefit bias removal. We set up a protocol to study how much shared knowledge exists between datasets of different organs, allowing us to effectively quantify the effect of unlearning. Our study reveals that combining training data can lead to hallucinations and reduced image quality in the reconstructed data. We use unlearning to remove hallucinations as a proxy exemplar of undesired data removal. In particular, we show that machine unlearning is possible without full retraining. Furthermore, our observations indicate that maintaining high performance is feasible even when using only a subset of the retain data. We have made our code publicly accessible.
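The sketch below gives a rough feel for "unlearning without full retraining, using only a subset of retain data". It is a toy example of one generic baseline and not necessarily the method implemented in this repository. The baseline runs gradient descent on a small retain subset combined with gradient ascent on the forget set. A one-parameter linear model stands in for an MRI reconstruction network. The functions `mse`, `grad`, `train`, `unlearn`, the weight `lam`, and the synthetic data are all hypothetical and introduced only for illustration.

```python
import random


def mse(w, data):
    """Mean squared error of the toy model y_hat = w * x on (x, y) pairs."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)


def grad(w, data):
    """Derivative of mse(w, data) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in data) / len(data)


def train(w, data, lr=0.1, steps=500):
    """Plain gradient descent on the mean squared error."""
    for _ in range(steps):
        w -= lr * grad(w, data)
    return w


def unlearn(w, retain_subset, forget, lam=0.1, lr=0.1, steps=500):
    """Minimise L_retain - lam * L_forget: descend on retain, ascend on forget.

    lam is kept small so that, in this quadratic toy, the combined objective
    stays convex and the iterates converge instead of diverging.
    """
    for _ in range(steps):
        w -= lr * (grad(w, retain_subset) - lam * grad(w, forget))
    return w


rng = random.Random(0)
xs = [rng.uniform(-1.0, 1.0) for _ in range(200)]
retain = [(x, 2.0 * x) for x in xs[:100]]  # wanted mapping: y = 2x
forget = [(x, 5.0 * x) for x in xs[100:]]  # unwanted mapping: y = 5x

w_combined = train(0.0, retain + forget)  # model trained on both sets
retain_subset = rng.sample(retain, 20)    # only 20% of the retain data
w_unlearned = unlearn(w_combined, retain_subset, forget)  # no retraining from scratch

for name, w in [("combined", w_combined), ("unlearned", w_unlearned)]:
    print(f"{name}: w={w:.3f} retain={mse(w, retain):.3f} forget={mse(w, forget):.3f}")
```

In this toy, the ascent term keeps pushing away from the forget mapping, so `w_unlearned` ends up slightly below the retain-only optimum of 2. In practice, the strength of such a term (here `lam`) has to be balanced against performance on the retain data.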
Install with pip

```bash
# clone project
git clone https://github.com/YuyangXueEd/ReconUnlearning
cd ReconUnlearning

# [OPTIONAL] create conda environment
conda create -n unrecon python=3.10
conda activate unrecon

# install pytorch according to instructions
# https://pytorch.org/get-started/

# install requirements
pip install -r requirements.txt
```
Or install with conda

```bash
# clone project
git clone https://github.com/YuyangXueEd/ReconUnlearning
cd ReconUnlearning

# create conda environment and install dependencies
conda env create -f environment.yaml -n unrecon

# activate conda environment
conda activate unrecon
```
Train the model with the default configuration

```bash
# train on CPU
python src/train.py trainer=cpu

# train on GPU
python src/train.py trainer=gpu
```
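To choose between the CPU and GPU commands, a small helper can check whether PyTorch reports a CUDA device. This is a sketch under the assumption that `trainer=gpu` expects a CUDA GPU. `pick_trainer` is a hypothetical helper and not part of the repository. `torch.cuda.is_available()` is the standard PyTorch call for this check.

```python
import importlib.util


def pick_trainer() -> str:
    """Return "gpu" if PyTorch is installed and reports a CUDA device, else "cpu"."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"  # PyTorch is not installed in this environment
    import torch

    return "gpu" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    # Print the matching training command for this machine.
    print(f"python src/train.py trainer={pick_trainer()}")
```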
Train the model with an experiment configuration chosen from `configs/experiment/`

```bash
python src/train.py experiment=experiment_name.yaml
```
You can override any parameter from the command line like this

```bash
python src/train.py trainer.max_epochs=20 data.batch_size=64
```
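Each override is a dot-separated path into the nested training configuration. The function `apply_overrides` below is a toy sketch of roughly what `trainer.max_epochs=20 data.batch_size=64` does to such a config. It is hypothetical and is not Hydra's implementation. It assumes that `src/train.py` is a Hydra application, as the `key=value` syntax suggests. Real Hydra parses values with its own override grammar. Like this sketch, it rejects keys that are not already in the config unless they are prefixed with `+`. The `defaults` dict is made up for illustration; the repository's real values live in its config files.

```python
import ast
import copy


def apply_overrides(config, overrides):
    """Return a copy of a nested dict with dotted "a.b=value" overrides applied.

    Values are parsed as Python literals when possible (so "20" becomes int 20),
    otherwise kept as strings. Unknown keys raise KeyError instead of being added.
    """
    cfg = copy.deepcopy(config)
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # KeyError if an intermediate key is unknown
        if leaf not in node:
            raise KeyError(f"unknown config key: {key}")
        try:
            node[leaf] = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            node[leaf] = raw  # e.g. "gpu" stays a plain string
    return cfg


# Hypothetical defaults, for illustration only.
defaults = {"trainer": {"max_epochs": 10}, "data": {"batch_size": 32}}
cfg = apply_overrides(defaults, ["trainer.max_epochs=20", "data.batch_size=64"])
print(cfg)  # {'trainer': {'max_epochs': 20}, 'data': {'batch_size': 64}}
```

Because the function works on a deep copy, the defaults stay untouched. This mirrors the idea that command-line overrides apply to a single run and do not change the config files.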