-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added support for custom network architectures and using tracking dat…
…a in model input. Restructured train config format for more dataset customizability. Updated tracking data normalization.
- Loading branch information
1 parent
2d09657
commit 2fc7107
Showing
9 changed files
with
247 additions
and
73 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
87 changes: 87 additions & 0 deletions
87
UltrasoundSegmentation/configs/UsTracking_train_config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Example config file for train.py | ||
|
||
network: custom # available networks: unet, attention_unet, effnetunet, unetplusplus, unetr, segresnet | ||
loss_function: dicefocal | ||
lambda_ce: 0.9 # 0.0 for dice, 1.0 for cross-entropy/focal | ||
class_weights: [0.2, 0.8] # must be same length as out_channels | ||
image_size: 128 | ||
in_channels: &in_c !!int 5 | ||
out_channels: !!int 2 | ||
num_epochs: !!int 1 | ||
batch_size: !!int 128 | ||
learning_rate: !!float 0.001 | ||
learning_rate_decay_factor: !!float 0.5 | ||
learning_rate_decay_frequency: !!int 10 | ||
weight_decay: 0.01 | ||
shuffle: !!bool true # true/false | ||
seed: !!int 42 | ||
|
||
dataset: | ||
train_folder: /mnt/c/Users/chris/Data/Spine/2024_SpineSeg/04_Slices_train | ||
val_folder: /mnt/c/Users/chris/Data/Spine/2024_SpineSeg/04_Slices_val | ||
name: SlidingWindowTrackedUSDataset | ||
params: | ||
window_size: *in_c | ||
gt_idx: 2 # 0: first, 1: middle, 2: last | ||
|
||
transforms: | ||
general: | ||
# Basic transforms, do not modify | ||
- name: Transposed | ||
params: | ||
keys: [image, label] | ||
indices: [2, 0, 1] | ||
- name: ToTensord | ||
- name: EnsureTyped | ||
params: | ||
keys: [image, label] | ||
dtype: float32 | ||
# Add additional transforms here | ||
- name: Resized | ||
params: | ||
keys: [image, label] | ||
spatial_size: [128, 128] | ||
- name: ScaleIntensityRanged | ||
params: | ||
keys: [image] | ||
a_min: 0.0 # minimum intensity in the original image | ||
a_max: 255.0 # maximum intensity in the original image for 8-bit images | ||
b_min: 0.0 # scaled minimum intensity | ||
b_max: 1.0 # scaled maximum intensity | ||
clip: true | ||
train: | ||
- name: RandGaussianNoised | ||
params: | ||
keys: [image] | ||
prob: 0.5 | ||
mean: 0.0 | ||
std: 0.1 | ||
- name: RandFlipd | ||
params: | ||
keys: [image, label] | ||
prob: 0.5 | ||
spatial_axis: [1] | ||
- name: RandAdjustContrastd | ||
params: | ||
keys: [image] | ||
prob: 0.5 | ||
gamma: [0.5, 2] | ||
- name: RandAffined | ||
params: | ||
keys: [image, label] | ||
prob: 0.5 | ||
spatial_size: [128, 128] | ||
rotate_range: 0.5 | ||
shear_range: [0.2, 0.2] | ||
translate_range: [20, 20] | ||
scale_range: [0.2, 0.2] | ||
mode: bilinear | ||
padding_mode: zeros | ||
cache_grid: true | ||
|
||
# Custom model parameters | ||
model: | ||
model_path: /home/chrisyeung/spine-segmentation/train_dualnet.py | ||
name: DualEncoderUNet | ||
use_tracking: true | ||
params: |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Example config file for train.py | ||
|
||
network: unet # available networks: unet, attention_unet, effnetunet, unetplusplus, unetr, segresnet, custom | ||
loss_function: DiceCE | ||
lambda_ce: 0.5 # 0.0 for dice, 1.0 for cross-entropy/focal | ||
class_weights: [0.25, 0.75] # must be same length as out_channels | ||
image_size: 128 | ||
in_channels: !!int 1 | ||
out_channels: !!int 2 | ||
num_epochs: !!int 100 | ||
batch_size: !!int 64 | ||
learning_rate: !!float 0.0001 | ||
learning_rate_decay_factor: !!float 0.5 | ||
learning_rate_decay_frequency: !!int 10 | ||
weight_decay: 0.01 | ||
shuffle: !!bool true # true/false | ||
seed: !!int 42 | ||
|
||
dataset: | ||
train_folder: /path/to/train | ||
val_folder: /path/to/val | ||
name: UltrasoundDataset | ||
params: # dataset-specific parameters here | ||
|
||
transforms: | ||
general: | ||
# Basic transforms, do not modify | ||
- name: Transposed | ||
params: | ||
keys: [image, label] | ||
indices: [2, 0, 1] | ||
- name: ToTensord | ||
- name: EnsureTyped | ||
params: | ||
keys: [image, label] | ||
dtype: float32 | ||
|
||
# Add additional transforms here | ||
- name: Resized | ||
params: | ||
keys: [image, label] | ||
spatial_size: [128, 128] | ||
train: | ||
|
||
# Custom model parameters | ||
model: | ||
model_path: /path/to/model.py | ||
name: CustomNet # name of the model class | ||
use_tracking: false # use tracking data in model input | ||
params: # model-specific parameters |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.