Added support for custom network architectures and using tracking data in model input. Restructured train config format for more dataset customizability. Updated tracking data normalization.
chriscyyeung committed Nov 27, 2024
1 parent 2d09657 commit 2fc7107
Showing 9 changed files with 247 additions and 73 deletions.
15 changes: 8 additions & 7 deletions UltrasoundSegmentation/README.md
@@ -164,10 +164,8 @@ Training hyperparameters can be modified on train_config.yaml.

Similar to running the prepare_data.py script, train.py can be run from the command line or by configuring a JSON file.

- * `--train-data-folder` should be the path of the folder with the training set (which should be a subset of the output of prepare_data.py)
- * `--val-data-folder` should be the path of the folder with validation set
* `--output-dir` is the name of the directory in which to save the run
- * `--config-file` is the yaml file detailing the training settings. See [train_config.yaml](train_config.yaml) for an example. Also see [Supported networks](#supported-networks) to see the available networks.
+ * `--config-file` is the yaml file detailing the training settings. See [train_config.yaml](configs/train_config.yaml) for an example. Also see [Supported networks](#supported-networks) to see the available networks.
* `--save-torchscript` saves the model in TorchScript format
* `--save-ckpt-freq` is the interval, in epochs, at which model checkpoints are saved (0 by default)
* `--wandb-entity-name` should be set to your wandb username for a solo project, or to the username of the team owner for a collaborative wandb project
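
For example, a typical command-line invocation might look like this (paths are illustrative; with this commit, the training and validation folders are set in the config file rather than as flags):

```
python train.py --output-dir D:/runs --config-file configs/train_config.yaml --save-torchscript --save-log
```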
@@ -195,9 +193,7 @@ To configure the JSON file for train.py, open the launch.json again. Copy and pa
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": "true",
"args": ["--train-data-folder", "D:/data/train",
"--val-data-folder", "D:/data/val",
"--output-dir", "D:/runs",
"args": ["--output-dir", "D:/runs",
"--save-torchscript",
"--save-log"]
}
@@ -237,10 +233,15 @@ The network architectures that are currently supported (and their required `name
- UNet++ (`unetplusplus`)
- UNETR (`unetr`)
- SegResNet (`segresnet`)
- Custom network (`custom`)
- nnUNet (`nnunet`)

Any network other than the nnUNet uses the hyperparameters described in the config file. Due to the nature of the nnUNet, many of its hyperparameters are set automatically and the config file is ignored. Using the nnUNet also requires additional packages and flags, described in the following section.

### Using a custom network architecture

A custom network architecture (one not from MONAI) can also be trained, as long as it is a subclass of `torch.nn.Module`. To do this, add the full path to the network's `.py` file and the name of the model class to the config file. An example is shown in [train_config.yaml](configs/train_config.yaml).
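
As a minimal sketch of such a file (the class name and constructor arguments here are illustrative placeholders matching the `name` and `params` fields of the config, not code from this repo):

```
import torch.nn as nn

class CustomNet(nn.Module):
    """Toy custom architecture; any torch.nn.Module subclass works."""

    def __init__(self, in_channels=1, out_channels=2, features=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, features, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, out_channels, kernel_size=1),
        )

    def forward(self, x):
        return self.net(x)
```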

### Using the nnUNet

The implementation of the nnUNet is found on [GitHub](https://github.com/MIC-DKFZ/nnUNet/tree/master). Most of the instructions in this guide can be found in much more detail on the [official repo](https://github.com/MIC-DKFZ/nnUNet/tree/master/documentation). This section will explain what is necessary to use the nnUNet from the aigt repo.
@@ -281,7 +282,7 @@ All training data from one dataset are located under the `Dataset` folder, which

The images themselves follow this naming convention: `{CASE_IDENTIFIER}_{CHANNEL}.{FILE_ENDING}`. In the above example, `LN003` is the case identifier, `0000` is the 4-digit channel identifier, and `nii.gz` is the file ending. The segmentations follow a similar convention, but without the channel identifier. This is further explained below.
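
As a toy illustration of the convention (names taken from the example above):

```
case_id, channel, ending = "LN003", 0, "nii.gz"
image_name = f"{case_id}_{channel:04d}.{ending}"  # LN003_0000.nii.gz
label_name = f"{case_id}.{ending}"                # LN003.nii.gz (no channel identifier)
```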

- Luckily, assuming your data has already been converted to slices (hopefully using the `--use-file-prefix` flag), `train.py` will convert the data into the nnUNet format for you. However, you must first set the values of 2 dictionaries in the `.yaml` config file. An example is shown in [train_config.yaml](train_config.yaml):
+ Assuming your data has already been converted to slices (hopefully using the `--use-file-prefix` flag), `train.py` will convert the data into the nnUNet format for you. However, you must first set the values of 2 dictionaries in the `.yaml` config file. An example is shown in [train_config.yaml](configs/train_config.yaml):

```
# For nnUNet only:
```
87 changes: 87 additions & 0 deletions UltrasoundSegmentation/configs/UsTracking_train_config.yaml
@@ -0,0 +1,87 @@
# Example config file for train.py

network: custom # available networks: unet, attention_unet, effnetunet, unetplusplus, unetr, segresnet, custom
loss_function: dicefocal
lambda_ce: 0.9 # 0.0 for dice, 1.0 for cross-entropy/focal
class_weights: [0.2, 0.8] # must be same length as out_channels
image_size: 128
in_channels: &in_c !!int 5  # anchored (&in_c) so the dataset window_size below reuses this value
out_channels: !!int 2
num_epochs: !!int 1
batch_size: !!int 128
learning_rate: !!float 0.001
learning_rate_decay_factor: !!float 0.5
learning_rate_decay_frequency: !!int 10
weight_decay: 0.01
shuffle: !!bool true # true/false
seed: !!int 42

dataset:
  train_folder: /mnt/c/Users/chris/Data/Spine/2024_SpineSeg/04_Slices_train
  val_folder: /mnt/c/Users/chris/Data/Spine/2024_SpineSeg/04_Slices_val
  name: SlidingWindowTrackedUSDataset
  params:
    window_size: *in_c
    gt_idx: 2 # 0: first, 1: middle, 2: last

transforms:
  general:
    # Basic transforms, do not modify
    - name: Transposed
      params:
        keys: [image, label]
        indices: [2, 0, 1]
    - name: ToTensord
    - name: EnsureTyped
      params:
        keys: [image, label]
        dtype: float32
    # Add additional transforms here
    - name: Resized
      params:
        keys: [image, label]
        spatial_size: [128, 128]
    - name: ScaleIntensityRanged
      params:
        keys: [image]
        a_min: 0.0 # minimum intensity in the original image
        a_max: 255.0 # maximum intensity in the original image for 8-bit images
        b_min: 0.0 # scaled minimum intensity
        b_max: 1.0 # scaled maximum intensity
        clip: true
  train:
    - name: RandGaussianNoised
      params:
        keys: [image]
        prob: 0.5
        mean: 0.0
        std: 0.1
    - name: RandFlipd
      params:
        keys: [image, label]
        prob: 0.5
        spatial_axis: [1]
    - name: RandAdjustContrastd
      params:
        keys: [image]
        prob: 0.5
        gamma: [0.5, 2]
    - name: RandAffined
      params:
        keys: [image, label]
        prob: 0.5
        spatial_size: [128, 128]
        rotate_range: 0.5
        shear_range: [0.2, 0.2]
        translate_range: [20, 20]
        scale_range: [0.2, 0.2]
        mode: bilinear
        padding_mode: zeros
        cache_grid: true

# Custom model parameters
model:
  model_path: /home/chrisyeung/spine-segmentation/train_dualnet.py
  name: DualEncoderUNet # name of the model class
  use_tracking: true # use tracking data in model input
  params: # model-specific parameters (empty here)
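
The transform entries above are plain name/params pairs. One plausible way for train.py to turn them into MONAI transforms (a sketch under the assumption that names are looked up in `monai.transforms`; not necessarily the repo's actual code):

```
import monai.transforms as T

def build_transforms(entries):
    """Instantiate MONAI dictionary transforms from config entries."""
    tfms = []
    for entry in entries or []:
        cls = getattr(T, entry["name"])     # e.g. T.Resized, T.RandFlipd
        params = entry.get("params") or {}  # entries without params (e.g. ToTensord) may need defaults
        tfms.append(cls(**params))
    return T.Compose(tfms)
```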
50 changes: 50 additions & 0 deletions UltrasoundSegmentation/configs/train_config.yaml
@@ -0,0 +1,50 @@
# Example config file for train.py

network: unet # available networks: unet, attention_unet, effnetunet, unetplusplus, unetr, segresnet, custom
loss_function: DiceCE
lambda_ce: 0.5 # 0.0 for dice, 1.0 for cross-entropy/focal
class_weights: [0.25, 0.75] # must be same length as out_channels
image_size: 128
in_channels: !!int 1
out_channels: !!int 2
num_epochs: !!int 100
batch_size: !!int 64
learning_rate: !!float 0.0001
learning_rate_decay_factor: !!float 0.5
learning_rate_decay_frequency: !!int 10
weight_decay: 0.01
shuffle: !!bool true # true/false
seed: !!int 42

dataset:
  train_folder: /path/to/train
  val_folder: /path/to/val
  name: UltrasoundDataset
  params: # dataset-specific parameters here

transforms:
  general:
    # Basic transforms, do not modify
    - name: Transposed
      params:
        keys: [image, label]
        indices: [2, 0, 1]
    - name: ToTensord
    - name: EnsureTyped
      params:
        keys: [image, label]
        dtype: float32

    # Add additional transforms here
    - name: Resized
      params:
        keys: [image, label]
        spatial_size: [128, 128]
  train:

# Custom model parameters
model:
  model_path: /path/to/model.py
  name: CustomNet # name of the model class
  use_tracking: false # use tracking data in model input
  params: # model-specific parameters
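
One plausible mechanism for loading the class named above from `model_path` (a sketch assuming dynamic import by file path; not necessarily the exact code in train.py):

```
import importlib.util

def load_model_class(model_path, class_name):
    """Import a .py file by path and return the named model class."""
    spec = importlib.util.spec_from_file_location("custom_model", model_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name)

# e.g. ModelClass = load_model_class("/path/to/model.py", "CustomNet")
#      model = ModelClass(**(config["model"]["params"] or {}))
```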
21 changes: 18 additions & 3 deletions UltrasoundSegmentation/datasets.py
@@ -80,14 +80,19 @@ def __getitem__(self, index):


class SlidingWindowTrackedUSDataset(Dataset):
    GT_CHANNEL_IDX_FIRST = 0
    GT_CHANNEL_IDX_MIDDLE = 1
    GT_CHANNEL_IDX_LAST = 2

    def __init__(
        self,
        root_folder,
        imgs_dir="images",
        gts_dir="labels",
        tfms_dir="transforms",
        transform=None,
-       window_size=4
+       window_size=5,
+       gt_idx=GT_CHANNEL_IDX_LAST
    ):
        # get names of subfolders in imgs_dir, gts_dir, and tfms_dir
        image_scans = [
@@ -126,6 +131,16 @@ def __init__(
        self.transform = transform
        self.window_size = window_size

        # which frame to use for ground truth
        if gt_idx == self.GT_CHANNEL_IDX_FIRST:
            self.gt_idx = 0
        elif gt_idx == self.GT_CHANNEL_IDX_MIDDLE:
            self.gt_idx = window_size // 2
        elif gt_idx == self.GT_CHANNEL_IDX_LAST:
            self.gt_idx = window_size - 1
        else:
            raise ValueError("Invalid gt_idx value. Must be 0, 1, or 2.")

    def __len__(self):
        return sum(
            len(self.data[scan]["image"]) - self.window_size + 1
@@ -146,8 +161,8 @@ def __getitem__(self, index):
            for i in range(self.window_size)
        ], axis=-1)  # shape: (H, W, window_size)

-       # only take middle frame as label
-       label = np.load(self.data[scan]["label"][index + self.window_size // 2])
+       # load the ground truth frame selected by gt_idx
+       label = np.load(self.data[scan]["label"][index + self.gt_idx])
        # If the label is 2D, add a channel dimension as the last dimension
        if len(label.shape) == 2:
            label = np.expand_dims(label, axis=-1)
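
A hypothetical usage of the updated class (the folder path and dict-style sample below are assumptions based on the configs above): with `window_size=5`, a sample stacks frames `i` through `i+4` as channels, and `gt_idx` selects whether frame `i`, `i+2`, or `i+4` provides the label.

```
dataset = SlidingWindowTrackedUSDataset(
    root_folder="/path/to/04_Slices_train",  # illustrative path
    window_size=5,
    gt_idx=SlidingWindowTrackedUSDataset.GT_CHANNEL_IDX_MIDDLE,  # label = frame i + 2
)
sample = dataset[0]  # assumed to be a dict with "image" and "label" keys
```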