-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
2,022 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from typing import Tuple | ||
|
||
# Based on the architecture from https://github.com/MrGiovanni/ModelsGenesis | ||
torch.use_deterministic_algorithms(False) | ||
|
||
|
||
class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm): | ||
def _check_input_dim(self, input: torch.Tensor) -> None: | ||
|
||
if input.dim() != 5: | ||
raise ValueError("expected 5D input (got {}D input)".format(input.dim())) | ||
# super(ContBatchNorm3d, self)._check_input_dim(input) | ||
|
||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
self._check_input_dim(input) | ||
return F.batch_norm( | ||
input, | ||
self.running_mean, | ||
self.running_var, | ||
self.weight, | ||
self.bias, | ||
True, | ||
self.momentum, | ||
self.eps, | ||
) | ||
|
||
|
||
class LUConv(nn.Module): | ||
def __init__(self, in_chan: int, out_chan: int, act: str): | ||
super(LUConv, self).__init__() | ||
self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1) | ||
self.bn1 = ContBatchNorm3d(out_chan) | ||
|
||
if act == "relu": | ||
self.activation = nn.ReLU(out_chan) | ||
elif act == "prelu": | ||
self.activation = nn.PReLU(out_chan) | ||
elif act == "elu": | ||
self.activation = nn.ELU(inplace=True) | ||
else: | ||
raise | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
out = self.activation(self.bn1(self.conv1(x))) | ||
return out | ||
|
||
|
||
def _make_nConv( | ||
in_channel: int, depth: int, act: str, double_chnnel: bool = False | ||
) -> nn.Module: | ||
if double_chnnel: | ||
layer1 = LUConv(in_channel, 32 * (2 ** (depth + 1)), act) | ||
layer2 = LUConv(32 * (2 ** (depth + 1)), 32 * (2 ** (depth + 1)), act) | ||
else: | ||
layer1 = LUConv(in_channel, 32 * (2**depth), act) | ||
layer2 = LUConv(32 * (2**depth), 32 * (2**depth) * 2, act) | ||
|
||
return nn.Sequential(layer1, layer2) | ||
|
||
|
||
# class InputTransition(nn.Module): | ||
# def __init__(self, outChans, elu): | ||
# super(InputTransition, self).__init__() | ||
# self.conv1 = nn.Conv3d(1, 16, kernel_size=5, padding=2) | ||
# self.bn1 = ContBatchNorm3d(16) | ||
# self.relu1 = ELUCons(elu, 16) | ||
# | ||
# def forward(self, x): | ||
# # do we want a PRELU here as well? | ||
# out = self.bn1(self.conv1(x)) | ||
# # split input in to 16 channels | ||
# x16 = torch.cat((x, x, x, x, x, x, x, x, | ||
# x, x, x, x, x, x, x, x), 1) | ||
# out = self.relu1(torch.add(out, x16)) | ||
# return out | ||
|
||
|
||
class DownTransition(nn.Module): | ||
def __init__(self, in_channel: int, depth: int, act: str): | ||
super(DownTransition, self).__init__() | ||
self.ops = _make_nConv(in_channel, depth, act) | ||
self.maxpool = nn.MaxPool3d(2) | ||
self.current_depth = depth | ||
|
||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | ||
if self.current_depth == 3: | ||
out = self.ops(x) | ||
out_before_pool = out | ||
else: | ||
out_before_pool = self.ops(x) | ||
out = self.maxpool(out_before_pool) | ||
return out, out_before_pool | ||
|
||
|
||
class UpTransition(nn.Module): | ||
def __init__(self, inChans: int, outChans: int, depth: int, act: str): | ||
super(UpTransition, self).__init__() | ||
self.depth = depth | ||
self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2) | ||
self.ops = _make_nConv(inChans + outChans // 2, depth, act, double_chnnel=True) | ||
|
||
def forward(self, x: torch.Tensor, skip_x: torch.Tensor) -> torch.Tensor: | ||
out_up_conv = self.up_conv(x) | ||
concat = torch.cat((out_up_conv, skip_x), 1) | ||
out = self.ops(concat) | ||
return out | ||
|
||
|
||
class OutputTransition(nn.Module): | ||
def __init__(self, inChans: int, n_labels: int): | ||
|
||
super(OutputTransition, self).__init__() | ||
self.final_conv = nn.Conv3d(inChans, n_labels, kernel_size=1) | ||
self.sigmoid = nn.Sigmoid() | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
out = self.sigmoid(self.final_conv(x)) | ||
return out | ||
|
||
|
||
class UNet3D(nn.Module): | ||
# the number of convolutions in each layer corresponds | ||
# to what is in the actual prototxt, not the intent | ||
def __init__(self, n_class: int = 1, act: str = "relu", for_cls: bool = False): | ||
super(UNet3D, self).__init__() | ||
self.for_cls = for_cls | ||
self.down_tr64 = DownTransition(1, 0, act) | ||
self.down_tr128 = DownTransition(64, 1, act) | ||
self.down_tr256 = DownTransition(128, 2, act) | ||
self.down_tr512 = DownTransition(256, 3, act) | ||
if not for_cls: | ||
self.up_tr256 = UpTransition(512, 512, 2, act) | ||
self.up_tr128 = UpTransition(256, 256, 1, act) | ||
self.up_tr64 = UpTransition(128, 128, 0, act) | ||
# self.out_tr = OutputTransition(64, n_class) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
self.out64, self.skip_out64 = self.down_tr64(x) | ||
self.out128, self.skip_out128 = self.down_tr128(self.out64) | ||
self.out256, self.skip_out256 = self.down_tr256(self.out128) | ||
self.out512, self.skip_out512 = self.down_tr512(self.out256) | ||
if self.for_cls: | ||
return self.out512 | ||
|
||
self.out_up_256 = self.up_tr256(self.out512, self.skip_out256) | ||
self.out_up_128 = self.up_tr128(self.out_up_256, self.skip_out128) | ||
self.out_up_64 = self.up_tr64(self.out_up_128, self.skip_out64) | ||
# self.out = self.out_tr(self.out_up_64) | ||
|
||
return self.out_up_64 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# 3D Medical Imaging Pre-training and Downstream Task Validation | ||
|
||
Self-supervision is employed to learn meaningful representations from unlabeled medical imaging data. By pre-training the model on this vast source of information, we equip it with a strong foundation of understanding the underlying data structure. This leads to significantly improved performance and faster convergence when fine-tuning on downstream tasks like classification and segmentation, compared to training from scratch. | ||
|
||
This example demonstrates how to pre-train a model on 3D MRI medical imaging using self-supervised learning techniques, specifically DINO, on large datasets. The pre-trained model can then be fine-tuned for downstream tasks such as classification and segmentation. The fine-tuning process is adaptable to various medical imaging tasks, even when working with small datasets. | ||
We use the NIH Osteoarthritis Initiative (OAI) dataset for this example which can be downloaded from https://nda.nih.gov/oai | ||
## Data Preparation | ||
|
||
For each training type (self-supervised, classification, segmentation), prepare a CSV file with the following structure: | ||
|
||
### Self-Supervised and Classification CSV | ||
|
||
| PatientID | path | fold | | ||
|-----------|------|------| | ||
| ID1 | /path/to/dicom/folder1 | 0 | | ||
| ID2 | /path/to/dicom/folder2 | 1 | | ||
|
||
For classification, you can add multiple categorical columns to predict and add them as "cls_targets" in the `classification_config.yaml`. | ||
For example, you can classify the disease status (Progression, Non-exposed control group) using the V00COHORT label in the OAI dataset. | ||
|
||
### Segmentation CSV | ||
|
||
| PatientID | img_path | seg_path | fold | | ||
|-----------|----------|----------|------| | ||
| ID1 | /path/to/image1.nii.gz | /path/to/segmentation1.nii.gz | train | | ||
| ID2 | /path/to/image2.nii.gz | /path/to/segmentation2.nii.gz | val | | ||
|
||
For example, you can segment the knee parts using the OAI iMorphics Segmentation. | ||
The "fold" column can be used for cross-validation and can contain any value. The values should be added to the "train/val/test_folds" in the config.yaml files. | ||
|
||
## Configuration | ||
|
||
Each training type has its own `config.yaml` file. Make sure to set the following parameters: | ||
|
||
- `results_dir`: Path to save results and checkpoints | ||
- `csv_path`: Path to the CSV file for the respective training type | ||
- `experiment`: Name of the experiment and also the name of the results folder | ||
- `train_folds`: List of fold values to use for training (e.g., [0, 1, 2]) | ||
- `val_folds`: List of fold values to use for validation (e.g., [3]) | ||
- `test_folds`: List of fold values to use for testing (e.g., [4]) | ||
- `test_ckpt`: Path to the checkpoint for testing. If set to "null", the model will train using the train and validation sets. If a path is provided, it will perform evaluation on the test set using the given checkpoint. | ||
|
||
To load pretrained weights or start from certain checkpoint you need to set only <b>one</b> of the following: | ||
- `suprem_weights`: Path to the backbone pretrained weights from SuPreM (download from https://github.com/MrGiovanni/SuPreM) | ||
- `dino_weights`: Path to the backbone pretrained weights from Dino | ||
- `resume_training_from`: Path to training checkpoint | ||
- `test_ckpt`: If set, the test set as defined in `test_folds` will be evaluated using this checkpoint | ||
If none of them are set you will train from scratch | ||
|
||
Pretrained weights can be downloaded from [SuPreM GitHub repository](https://github.com/MrGiovanni/SuPreM). | ||
|
||
## Training | ||
|
||
The training process involves three main steps: | ||
|
||
1. Self-supervised pre-training with DINO | ||
2. Fine-tuning for classification | ||
3. Fine-tuning for segmentation | ||
|
||
### 1. Self-Supervised Pre-training | ||
|
||
Run DINO pre-training: | ||
|
||
```bash | ||
python fuse_examples/imaging/oai_example/self_supervised/dino.py | ||
``` | ||
|
||
### 2. Classification Fine-tuning | ||
|
||
Set dino_weights in classification_config.yaml to the path of the best DINO checkpoint. | ||
Run classification training: | ||
|
||
```bash | ||
python fuse_examples/imaging/oai_example/downstream/classification.py | ||
``` | ||
|
||
### 3. Segmentation Fine-tuning | ||
|
||
Set dino_weights in segmentation_config.yaml to the same DINO checkpoint path. | ||
Run segmentation training: | ||
```bash | ||
python fuse_examples/imaging/oai_example/downstream/segmentation.py | ||
``` | ||
This process leverages transfer learning, using DINO pre-trained weights to improve performance on downstream tasks. | ||
|
||
### Monitoring Results | ||
|
||
You can track the progress of your training/testing using one of the following methods: | ||
|
||
1. TensorBoard: | ||
To view losses and metrics, run: | ||
``` | ||
tensorboard --logdir=<path_to_experiments_directory> | ||
``` | ||
2. ClearML: | ||
If ClearML is installed and enabled in your config file (`clearml : True`), you can use it to monitor your results. | ||
|
||
Choose the method that best suits your workflow and preferences. |
Empty file.
Oops, something went wrong.