PyTorch implementation and results for IoU-based positives, a novel strategy for generating positive image crops for visual self-supervised learning (full text).
The method can replace standard random cropping in self-supervised methods such as SimCLR, DINO, and iBOT. Compared to plain random cropping, pretraining with IoU-based positives on 2D CT medical images yields better results when fine-tuning on a downstream organ segmentation task.
It has been noted in a few recent papers (see References) that plain random cropping for generating positive pairs may be suboptimal due to potentially significant semantic misalignment between the two crops. While this misalignment may not be clearly visible on images from ImageNet, it is much more apparent on medical images:
In contrast to previous works that propose alternative strategies for generating positive crops for medical data (see References), IoU-based positives requires no metadata, generalizes to both 2D and 3D data, uses no external models, and works seamlessly with heterogeneous datasets.
The proposed method works as follows. Given a 2D or 3D input image, two random crops are sampled such that the intersection over union (IoU) of their boxes falls within a predefined interval [min_iou, max_iou]; the two crops are then used as a positive pair. The values of min_iou and max_iou control how strongly the positive crops overlap, and therefore how semantically aligned they are (see the IoU intervals in the Results table below).
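To make the idea concrete, here is a minimal sketch of IoU-constrained crop sampling for the 2D case via rejection sampling. This is an illustrative sketch only, not the repository's actual implementation; the function names, the fixed crop size, and the fallback behaviour are assumptions.

```python
import random

def box_iou(a, b):
    """IoU of two axis-aligned boxes given as (y0, x0, y1, x1)."""
    inter_h = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_w = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_h * inter_w
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

def random_box(h, w, crop_h, crop_w):
    """Sample a crop box uniformly inside an h x w image."""
    y0 = random.randint(0, h - crop_h)
    x0 = random.randint(0, w - crop_w)
    return (y0, x0, y0 + crop_h, x0 + crop_w)

def sample_positive_pair(h, w, crop_h, crop_w, min_iou, max_iou, max_tries=100):
    """Rejection-sample two crop boxes whose IoU falls in [min_iou, max_iou]."""
    box1 = random_box(h, w, crop_h, crop_w)
    for _ in range(max_tries):
        box2 = random_box(h, w, crop_h, crop_w)
        if min_iou <= box_iou(box1, box2) <= max_iou:
            return box1, box2
    return box1, box1  # fallback if no valid pair is found (an assumption)
```

The two boxes would then be cut out of the image and passed through the usual augmentation pipeline before being fed to the SSL objective; in 3D, the same logic applies with six box coordinates.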
Mean DSC reports the mean ± standard deviation of the Dice similarity coefficient (DSC) over four fine-tuning runs (columns DSC 1-4). The "link" entries point to the corresponding wandb runs; the first row is the baseline without any pretraining.

IoU interval | Pretraining 1 | Pretraining 2 | Mean DSC | DSC 1 | DSC 2 | DSC 3 | DSC 4 |
---|---|---|---|---|---|---|---|
- | - | - | 0.843 ± 0.004 | 0.842 | 0.848 | 0.846 | 0.836 |
[0, 0] | link | link | 0.838 ± 0.007 | 0.836 | 0.831 | 0.834 | 0.850 |
[0.0001, 0.3] | link | link | 0.839 ± 0.012 | 0.847 | 0.831 | 0.855 | 0.824 |
[0.3, 0.6] | link | link | 0.853 ± 0.002 | 0.854 | 0.851 | 0.851 | 0.857 |
[0.6, 1] | link | link | 0.843 ± 0.013 | 0.850 | 0.855 | 0.848 | 0.822 |
Random crop | link | link | 0.824 ± 0.011 | 0.815 | 0.841 | 0.828 | 0.812 |
FLARE 2022 challenge data was used both for pretraining and fine-tuning. The challenge's training set includes 50 CT scans with voxel-level labels of 13 abdominal organs plus 2,000 unlabeled CT scans; the validation set includes 50 visible unlabeled cases, and the testing set includes 200 hidden cases. The present work used the challenge's training set only: all 2,000 unlabeled CT scans (without any metadata) for pretraining the backbone, and the 50 labeled cases for fine-tuning. To evaluate the IoU intervals reliably, the majority of labeled cases were held out for validation: 35 cases formed the fine-tuning validation subset and the remaining 15 cases formed the fine-tuning training subset.
2D CT slices were used for pretraining and fine-tuning instead of full 3D volumes to reduce computational cost.
See `requirements.txt`.
- Preprocess CTs using `preprocess_flare_labelled.py` and `preprocess_flare_unlabelled.py` (default argument values were used for the experiments). This extracts 2D .png files from the 3D .nii.gz files, so the same processing does not have to be repeated every time an image is loaded during training; a rough sketch of this conversion is given after this list.
- Run pretraining(s).
python main_simsiam.py --data_dir <DIR_WITH_PREPROCESSED_2D_PNGS> --embedding_size 48 --batch_size 128 --n_epochs 100 --base_lr 0.025 --min_iou <I_MIN> --max_iou <I_MAX> --num_workers <NUM_WORKERS> --use_amp
- Run fine-tuning(s).
python main_finetune.py --data_dir <DIR_WITH_PREPROCESSED_2D_PNGS> --chkpt_path <PATH_TO_PRETRAINED_CHKPT> --embedding_size 48 --batch_size 32 --n_epochs 225 --patience 20 --sw_batch_size 64 --ignore_user_warning --num_workers <NUM_WORKERS> --use_amp
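For reference, the conversion performed by the preprocessing step looks roughly like the sketch below. This is a simplified, assumed version, not the actual scripts, which may differ in windowing, spacing, and file naming; `nibabel` and `Pillow` are assumed to be available, and the Hounsfield window is an arbitrary choice.

```python
import os
import nibabel as nib
import numpy as np
from PIL import Image

def nii_to_pngs(nii_path, out_dir, hu_window=(-160, 240)):
    """Slice a 3D CT volume into 8-bit axial 2D .png files."""
    vol = nib.load(nii_path).get_fdata()  # typically (H, W, D), Hounsfield units
    lo, hi = hu_window                    # assumed abdominal soft-tissue window
    vol = np.clip(vol, lo, hi)
    vol = ((vol - lo) / (hi - lo) * 255.0).astype(np.uint8)
    os.makedirs(out_dir, exist_ok=True)
    base = os.path.basename(nii_path).replace(".nii.gz", "")
    for z in range(vol.shape[-1]):
        Image.fromarray(vol[..., z]).save(
            os.path.join(out_dir, f"{base}_{z:04d}.png")
        )
```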
Arguments used for running each experiment can be found in the corresponding wandb runs (see the table in Results): Files -> `config.yaml`.
There's an analogous script `main_dino.py` for pretraining with DINO. One can also run all the scripts on 3D data (using `--spatial_dims 3`).
Be careful though! At the moment, I can't guarantee that it will work, and your PC might blow up. ;)
[1] Senthil Purushwalkam and Abhinav Gupta. “Demystifying Contrastive Self-Supervised Learning: Invariances, Augmentations and Dataset Biases”. Advances in Neural Information Processing Systems, 2020.
[2] Xiangyu Peng et al. “Crafting Better Contrastive Views for Siamese Representation Learning”. IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2022.
[3] Shekoofeh Azizi et al. “Big Self-Supervised Models Advance Medical Image Classification”. IEEE/CVF International Conference on Computer Vision (ICCV), 2021.
[4] Yen Nhi Truong Vu et al. “MedAug: Contrastive learning leveraging patient metadata improves representations for chest X-ray interpretation”. Proceedings of the 6th Machine Learning for Healthcare Conference, 2021.
[5] Dewen Zeng et al. “Positional Contrastive Learning for Volumetric Medical Image Segmentation”. Medical Image Computing and Computer Assisted Intervention – MICCAI, 2021.
[6] Yankai Jiang et al. “Anatomical Invariance Modeling and Semantic Alignment for Self-supervised Learning in 3D Medical Image Segmentation”. arXiv:2302.05615 [cs.CV], 2023.