# SAM-Swin: SAM-Driven Dual-Swin Transformers with Adaptive Lesion Enhancement for Laryngo-Pharyngeal Tumor Detection
This repo is the official implementation of SAM-Swin: SAM-Driven Dual-Swin Transformers with Adaptive Lesion Enhancement for Laryngo-Pharyngeal Tumor Detection.
SAM-Swin consists of four pivotal components: the SAM2-guided lesion location module (SAM2-GLLM), the whole-image branch (WIB), the lesion region branch (LRB), and the multi-scale lesion-aware enhancement module (MS-LAEM).
- By leveraging the advanced object segmentation capabilities of SAM2, we pioneer its integration into the SAM-Swin framework, enabling highly precise segmentation of lesion regions.
- We propose the MS-LAEM, designed to adaptively enhance the learning of nuanced complementary features across various scales, improving the quality of feature extraction and representation.
- We introduce the multi-scale class-aware guidance (CAG) loss, a novel approach that employs targeted supervision to facilitate the extraction of class-specific features within the model (a sketch follows this list).
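
The README does not spell out the loss implementation, so here is a minimal PyTorch sketch of one plausible form of a multi-scale class-supervised loss: an auxiliary linear head with cross-entropy supervision at each feature scale. The names (`MultiScaleCAGLoss`, `channels_per_scale`) and the loss weight are hypothetical, not the repo's actual API.

```python
# Hypothetical sketch of a multi-scale class-aware guidance (CAG) loss.
import torch
import torch.nn as nn


class MultiScaleCAGLoss(nn.Module):
    """Supervises features at several scales with lightweight class heads."""

    def __init__(self, channels_per_scale, num_classes=3, weight=0.1):
        super().__init__()
        # One auxiliary classifier per feature scale.
        self.heads = nn.ModuleList(
            nn.Linear(c, num_classes) for c in channels_per_scale
        )
        self.ce = nn.CrossEntropyLoss()
        self.weight = weight  # illustrative choice, not the paper's value

    def forward(self, multi_scale_feats, labels):
        # multi_scale_feats: list of (B, C_i, H_i, W_i) feature maps.
        loss = 0.0
        for head, feat in zip(self.heads, multi_scale_feats):
            pooled = feat.mean(dim=(2, 3))  # global average pooling
            loss = loss + self.ce(head(pooled), labels)
        return self.weight * loss
```

In such a setup the CAG term would simply be added to the main classification loss during training.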
To fine-tune SAM2 for your own tasks, we recommend following the guidelines provided in the original repository: MedSAM2
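
Once a fine-tuned SAM2 produces a lesion mask, the lesion region branch needs a crop of that region. The helper below is an illustrative sketch of that step, not the repo's actual SAM2-GLLM code; `mask` is assumed to be a binary H x W array from SAM2, and the padding `margin` is an arbitrary choice.

```python
# Hypothetical helper: turn a SAM2 lesion mask into the crop fed to the LRB.
import numpy as np


def crop_lesion_region(image: np.ndarray, mask: np.ndarray, margin: int = 16) -> np.ndarray:
    """Crop the tight bounding box around the predicted lesion, with padding."""
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        return image  # no lesion predicted: fall back to the whole image
    h, w = mask.shape
    y0, y1 = max(ys.min() - margin, 0), min(ys.max() + margin, h - 1)
    x0, x1 = max(xs.min() - margin, 0), min(xs.max() + margin, w - 1)
    return image[y0 : y1 + 1, x0 : x1 + 1]
```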
Organize your datasets in the following manner:
```
datasets/
├── dataset1/
│   ├── global/
│   │   ├── train/
│   │   │   ├── benign/
│   │   │   ├── normal/
│   │   │   └── tumor/
│   │   ├── val/
│   │   │   ├── benign/
│   │   │   ├── normal/
│   │   │   └── tumor/
│   │   └── test/
│   │       ├── benign/
│   │       ├── normal/
│   │       └── tumor/
│   └── local_seg/
│       ├── train/
│       │   ├── benign/
│       │   ├── normal/
│       │   └── tumor/
│       ├── val/
│       │   ├── benign/
│       │   ├── normal/
│       │   └── tumor/
│       └── test/
│           ├── benign/
│           ├── normal/
│           └── tumor/
├── dataset6/
│   └── ...
```
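
This class-per-subfolder layout matches what `torchvision.datasets.ImageFolder` expects, so a split can be loaded as sketched below. The transform and paths are illustrative, not the repo's actual data pipeline.

```python
# Sketch: load one split of the layout above with torchvision's ImageFolder.
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize((256, 256)),  # matches the SwinV2 window16_256 input size
    transforms.ToTensor(),
])

global_train = datasets.ImageFolder("datasets/dataset1/global/train", transform)
local_train = datasets.ImageFolder("datasets/dataset1/local_seg/train", transform)
print(global_train.classes)  # ['benign', 'normal', 'tumor']
```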
We train SAM-Swin in two stages.
- Stage 1, run:

```bash
python -m torch.distributed.launch --nproc_per_node 2 --master_port 12345 main.py --cfg configs/dynamic.yaml --batch-size 32 --pretrained swinv2_base_patch4_window16_256.pth --cache-mode full --amp-opt-level O1 --accumulation-steps 4 --fused_window_process --fused_layernorm --tag exp
```
- Stage 2, run:

```bash
python -m torch.distributed.launch --nproc_per_node 2 --master_port 12345 main.py --cfg configs/ft_baseline.yaml --batch-size 32 --pretrained <path to your latest Stage 1 checkpoint> --cache-mode full --amp-opt-level O1 --accumulation-steps 4 --fused_window_process --fused_layernorm --tag exp_ft
```
To evaluate a checkpoint using DDP, run:

```bash
python -m torch.distributed.launch --nproc_per_node 2 --master_port 12345 main.py --cfg configs/dynamic.yaml --resume <path to your checkpoint> --cache-mode full --amp-opt-level O1 --accumulation-steps 4 --fused_window_process --fused_layernorm --tag exp --eval
```
The code of SAM-Swin is built upon MedSAM2 and Swin Transformer, and we are grateful to these awesome projects.
```bibtex
@misc{wei2024samswinsamdrivendualswintransformers,
  title={SAM-Swin: SAM-Driven Dual-Swin Transformers with Adaptive Lesion Enhancement for Laryngo-Pharyngeal Tumor Detection},
  author={Jia Wei and Yun Li and Xiaomao Fan and Wenjun Ma and Meiyu Qiu and Hongyu Chen and Wenbin Lei},
  year={2024},
  eprint={2410.21813},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2410.21813},
}
```