Skip to content

Commit

Permalink
Oai example (#361)
Browse files Browse the repository at this point in the history
  • Loading branch information
liamhazan authored Aug 16, 2024
1 parent 594252e commit 66ffe09
Show file tree
Hide file tree
Showing 14 changed files with 2,022 additions and 1 deletion.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ $ pip install fuse-med-ml[all,examples]
# Examples

* Easy access "Hello World" [colab notebook](https://colab.research.google.com/github/BiomedSciAI/fuse-med-ml/blob/master/fuse_examples/imaging/hello_world/hello_world.ipynb)
* classification
* Classification
* [**MNIST**](./fuse_examples/imaging/classification/mnist/) - a simple example, including training, inference and evaluation over [MNIST dataset](http://yann.lecun.com/exdb/mnist/)
* [**STOIC**](./fuse_examples/imaging/classification/stoic21/) - severe COVID-19 classifier baseline given a Computed-Tomography (CT), age group and gender. [Challenge description](https://stoic2021.grand-challenge.org/)

Expand All @@ -232,6 +232,9 @@ $ pip install fuse-med-ml[all,examples]
* [**Skin Lesion**](./fuse_examples/imaging/classification/isic/) - skin lesion classification , including training, inference and evaluation over the public dataset introduced in [ISIC challenge](https://challenge.isic-archive.com/landing/2019)
* [**Breast Cancer Lesion Classification**](./fuse_examples/imaging/classification/cmmd) - lesions classification of tumor ( benign, malignant) in breast mammography over the public dataset introduced in [The Chinese Mammography Database (CMMD)](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=70230508)
* [**Mortality prediction for ICU patients**](./fuse_examples/multimodality/ehr_transformer) - Example of EHR transformer applied to the data of Intensive Care Units patients for in-hospital mortality prediction. The dataset is from [PhysioNet Computing in Cardiology Challenge (2012)](https://physionet.org/content/challenge-2012/1.0.0/)
* Pre-training
* [**Medical Imaging Pre-training and Downstream Task Validation**](./fuse_examples/imaging/oai_example) - pre-training a model on 3D MRI medical imaging and then using it for classification and segmentation downstream tasks.


## Walkthrough template
* [**Walkthrough Template**](./fuse/dl/templates/walkthrough_template.py) - includes several TODO notes, marking the minimal scope of code required to get your pipeline up and running. The template also includes useful explanations and tips.
Expand Down
153 changes: 153 additions & 0 deletions fuse/dl/models/backbones/backbone_unet3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

# Based on the architecture from https://github.com/MrGiovanni/ModelsGenesis
torch.use_deterministic_algorithms(False)


class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
def _check_input_dim(self, input: torch.Tensor) -> None:

if input.dim() != 5:
raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
# super(ContBatchNorm3d, self)._check_input_dim(input)

def forward(self, input: torch.Tensor) -> torch.Tensor:
self._check_input_dim(input)
return F.batch_norm(
input,
self.running_mean,
self.running_var,
self.weight,
self.bias,
True,
self.momentum,
self.eps,
)


class LUConv(nn.Module):
def __init__(self, in_chan: int, out_chan: int, act: str):
super(LUConv, self).__init__()
self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1)
self.bn1 = ContBatchNorm3d(out_chan)

if act == "relu":
self.activation = nn.ReLU(out_chan)
elif act == "prelu":
self.activation = nn.PReLU(out_chan)
elif act == "elu":
self.activation = nn.ELU(inplace=True)
else:
raise

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.activation(self.bn1(self.conv1(x)))
return out


def _make_nConv(
in_channel: int, depth: int, act: str, double_chnnel: bool = False
) -> nn.Module:
if double_chnnel:
layer1 = LUConv(in_channel, 32 * (2 ** (depth + 1)), act)
layer2 = LUConv(32 * (2 ** (depth + 1)), 32 * (2 ** (depth + 1)), act)
else:
layer1 = LUConv(in_channel, 32 * (2**depth), act)
layer2 = LUConv(32 * (2**depth), 32 * (2**depth) * 2, act)

return nn.Sequential(layer1, layer2)


# class InputTransition(nn.Module):
# def __init__(self, outChans, elu):
# super(InputTransition, self).__init__()
# self.conv1 = nn.Conv3d(1, 16, kernel_size=5, padding=2)
# self.bn1 = ContBatchNorm3d(16)
# self.relu1 = ELUCons(elu, 16)
#
# def forward(self, x):
# # do we want a PRELU here as well?
# out = self.bn1(self.conv1(x))
# # split input in to 16 channels
# x16 = torch.cat((x, x, x, x, x, x, x, x,
# x, x, x, x, x, x, x, x), 1)
# out = self.relu1(torch.add(out, x16))
# return out


class DownTransition(nn.Module):
def __init__(self, in_channel: int, depth: int, act: str):
super(DownTransition, self).__init__()
self.ops = _make_nConv(in_channel, depth, act)
self.maxpool = nn.MaxPool3d(2)
self.current_depth = depth

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.current_depth == 3:
out = self.ops(x)
out_before_pool = out
else:
out_before_pool = self.ops(x)
out = self.maxpool(out_before_pool)
return out, out_before_pool


class UpTransition(nn.Module):
def __init__(self, inChans: int, outChans: int, depth: int, act: str):
super(UpTransition, self).__init__()
self.depth = depth
self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2)
self.ops = _make_nConv(inChans + outChans // 2, depth, act, double_chnnel=True)

def forward(self, x: torch.Tensor, skip_x: torch.Tensor) -> torch.Tensor:
out_up_conv = self.up_conv(x)
concat = torch.cat((out_up_conv, skip_x), 1)
out = self.ops(concat)
return out


class OutputTransition(nn.Module):
def __init__(self, inChans: int, n_labels: int):

super(OutputTransition, self).__init__()
self.final_conv = nn.Conv3d(inChans, n_labels, kernel_size=1)
self.sigmoid = nn.Sigmoid()

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.sigmoid(self.final_conv(x))
return out


class UNet3D(nn.Module):
# the number of convolutions in each layer corresponds
# to what is in the actual prototxt, not the intent
def __init__(self, n_class: int = 1, act: str = "relu", for_cls: bool = False):
super(UNet3D, self).__init__()
self.for_cls = for_cls
self.down_tr64 = DownTransition(1, 0, act)
self.down_tr128 = DownTransition(64, 1, act)
self.down_tr256 = DownTransition(128, 2, act)
self.down_tr512 = DownTransition(256, 3, act)
if not for_cls:
self.up_tr256 = UpTransition(512, 512, 2, act)
self.up_tr128 = UpTransition(256, 256, 1, act)
self.up_tr64 = UpTransition(128, 128, 0, act)
# self.out_tr = OutputTransition(64, n_class)

def forward(self, x: torch.Tensor) -> torch.Tensor:
self.out64, self.skip_out64 = self.down_tr64(x)
self.out128, self.skip_out128 = self.down_tr128(self.out64)
self.out256, self.skip_out256 = self.down_tr256(self.out128)
self.out512, self.skip_out512 = self.down_tr512(self.out256)
if self.for_cls:
return self.out512

self.out_up_256 = self.up_tr256(self.out512, self.skip_out256)
self.out_up_128 = self.up_tr128(self.out_up_256, self.skip_out128)
self.out_up_64 = self.up_tr64(self.out_up_128, self.skip_out64)
# self.out = self.out_tr(self.out_up_64)

return self.out_up_64
98 changes: 98 additions & 0 deletions fuse_examples/imaging/oai_example/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# 3D Medical Imaging Pre-training and Downstream Task Validation

Self-supervision is employed to learn meaningful representations from unlabeled medical imaging data. By pre-training the model on this vast source of information, we equip it with a strong foundation of understanding the underlying data structure. This leads to significantly improved performance and faster convergence when fine-tuning on downstream tasks like classification and segmentation, compared to training from scratch.

This example demonstrates how to pre-train a model on 3D MRI medical imaging using self-supervised learning techniques, specifically DINO, on large datasets. The pre-trained model can then be fine-tuned for downstream tasks such as classification and segmentation. The fine-tuning process is adaptable to various medical imaging tasks, even when working with small datasets.
We use the NIH Osteoarthritis Initiative (OAI) dataset for this example which can be downloaded from https://nda.nih.gov/oai
## Data Preparation

For each training type (self-supervised, classification, segmentation), prepare a CSV file with the following structure:

### Self-Supervised and Classification CSV

| PatientID | path | fold |
|-----------|------|------|
| ID1 | /path/to/dicom/folder1 | 0 |
| ID2 | /path/to/dicom/folder2 | 1 |

For classification, you can add multiple categorical columns to predict and add them as "cls_targets" in the `classification_config.yaml`.
For example, you can classify the disease status (Progression, Non-exposed control group) using the V00COHORT label in the OAI dataset.

### Segmentation CSV

| PatientID | img_path | seg_path | fold |
|-----------|----------|----------|------|
| ID1 | /path/to/image1.nii.gz | /path/to/segmentation1.nii.gz | train |
| ID2 | /path/to/image2.nii.gz | /path/to/segmentation2.nii.gz | val |

For example, you can segment the knee parts using the OAI iMorphics Segmentation.
The "fold" column can be used for cross-validation and can contain any value. The values should be added to the "train/val/test_folds" in the config.yaml files.

## Configuration

Each training type has its own `config.yaml` file. Make sure to set the following parameters:

- `results_dir`: Path to save results and checkpoints
- `csv_path`: Path to the CSV file for the respective training type
- `experiment`: Name of the experiment and also the name of the results folder
- `train_folds`: List of fold values to use for training (e.g., [0, 1, 2])
- `val_folds`: List of fold values to use for validation (e.g., [3])
- `test_folds`: List of fold values to use for testing (e.g., [4])
- `test_ckpt`: Path to the checkpoint for testing. If set to "null", the model will train using the train and validation sets. If a path is provided, it will perform evaluation on the test set using the given checkpoint.

To load pretrained weights or start from certain checkpoint you need to set only <b>one</b> of the following:
- `suprem_weights`: Path to the backbone pretrained weights from SuPreM (download from https://github.com/MrGiovanni/SuPreM)
- `dino_weights`: Path to the backbone pretrained weights from Dino
- `resume_training_from`: Path to training checkpoint
- `test_ckpt`: If set, the test set as defined in `test_folds` will be evaluated using this checkpoint
If none of them are set you will train from scratch

Pretrained weights can be downloaded from [SuPreM GitHub repository](https://github.com/MrGiovanni/SuPreM).

## Training

The training process involves three main steps:

1. Self-supervised pre-training with DINO
2. Fine-tuning for classification
3. Fine-tuning for segmentation

### 1. Self-Supervised Pre-training

Run DINO pre-training:

```bash
python fuse_examples/imaging/oai_example/self_supervised/dino.py
```

### 2. Classification Fine-tuning

Set dino_weights in classification_config.yaml to the path of the best DINO checkpoint.
Run classification training:

```bash
python fuse_examples/imaging/oai_example/downstream/classification.py
```

### 3. Segmentation Fine-tuning

Set dino_weights in segmentation_config.yaml to the same DINO checkpoint path.
Run segmentation training:
```bash
python fuse_examples/imaging/oai_example/downstream/segmentation.py
```
This process leverages transfer learning, using DINO pre-trained weights to improve performance on downstream tasks.

### Monitoring Results

You can track the progress of your training/testing using one of the following methods:

1. TensorBoard:
To view losses and metrics, run:
```
tensorboard --logdir=<path_to_experiments_directory>
```
2. ClearML:
If ClearML is installed and enabled in your config file (`clearml : True`), you can use it to monitor your results.

Choose the method that best suits your workflow and preferences.
Empty file.
Loading

0 comments on commit 66ffe09

Please sign in to comment.