From 2fc7107aa0193cc991402084685c2f061ef05bd3 Mon Sep 17 00:00:00 2001 From: Chris Yeung Date: Wed, 27 Nov 2024 17:16:21 -0500 Subject: [PATCH] Added support for custom network architectures and using tracking data in model input. Restructured train config format for more dataset customizability. Updated tracking data normalization. --- UltrasoundSegmentation/README.md | 15 +-- .../configs/UsTracking_train_config.yaml | 87 ++++++++++++++ .../{ => configs}/prepare_data_config.yaml | 0 .../configs/train_config.yaml | 50 ++++++++ .../{ => configs}/train_nnunet_config.yaml | 0 UltrasoundSegmentation/datasets.py | 21 +++- UltrasoundSegmentation/train.py | 109 +++++++++++++----- UltrasoundSegmentation/train_config.yaml | 36 ------ UltrasoundSegmentation/train_nnunet.py | 2 +- 9 files changed, 247 insertions(+), 73 deletions(-) create mode 100644 UltrasoundSegmentation/configs/UsTracking_train_config.yaml rename UltrasoundSegmentation/{ => configs}/prepare_data_config.yaml (100%) create mode 100644 UltrasoundSegmentation/configs/train_config.yaml rename UltrasoundSegmentation/{ => configs}/train_nnunet_config.yaml (100%) delete mode 100644 UltrasoundSegmentation/train_config.yaml diff --git a/UltrasoundSegmentation/README.md b/UltrasoundSegmentation/README.md index b5ba826..3bc1f26 100644 --- a/UltrasoundSegmentation/README.md +++ b/UltrasoundSegmentation/README.md @@ -164,10 +164,8 @@ Training hyperparameters can be modified on train_config.yaml. Similar to running the prepare_data.py script, train.py can be run from command line or by configuring a JSON file. -* `--train-data-folder` should be the path of the folder with the training set (which should be a subset of the output of prepare_data.py) -* `--val-data-folder` should be the path of the folder with validation set * `--output-dir` is the name of the directory in which to save the run -* `--config-file` is the yaml file detailing the training settings. See [train_config.yaml](train_config.yaml) for an example. Also see [Supported networks](#supported-networks) to see the available networks. +* `--config-file` is the yaml file detailing the training settings. See [train_config.yaml](configs/train_config.yaml) for an example. Also see [Supported networks](#supported-networks) to see the available networks. * `--save-torchscript` saves the model as a torchscript * `--save-ckpt-freq` is the integer value for how often (number of epochs) the model saves and is 0 by default * `--wandb-entity-name` should be set to your username if you are working on a solo project or the username of the owner of a collaborative team on wandb @@ -195,9 +193,7 @@ To configure the JSON file for train.py, open the launch.json again. Copy and pa "program": "${file}", "console": "integratedTerminal", "justMyCode": "true", - "args": ["--train-data-folder", "D:/data/train", - "--val-data-folder", "D:/data/val", - "--output-dir", "D:/runs", + "args": ["--output-dir", "D:/runs", "--save-torchscript", "--save-log"] } @@ -237,10 +233,15 @@ The network architectures that are currently supported (and their required `name - UNet++ (`unetplusplus`) - UNETR (`unetr`) - SegResNet (`segresnet`) +- Custom network (`custom`) - nnUNet (`nnunet`) Using any of the networks other than the nnUNet will use the hyperparameters described in the config file. Otherwise, due to the nature of the nnUNet, many of the hyperparameters will be automatically set and the config file will be ignored. However, using the nnUNet requires additional packages and flags to be set. This will be described in the following section. +### Use a custom network architecture + +A custom network architecture (not from monai) can be trained as well, as long as it is a subclass of `torch.nn.Module`. To do this, include the full path of the network's `.py` file along with the name of the model class to the config file. An example is shown in [train_config.yaml](configs/train_config.yaml). + ### Using the nnUNet The implementation of the nnUNet is found on [GitHub](https://github.com/MIC-DKFZ/nnUNet/tree/master). Most of the instructions in this guide can be found in much more detail on the [official repo](https://github.com/MIC-DKFZ/nnUNet/tree/master/documentation). This section will explain what is necessary to use the nnUNet from the aigt repo. @@ -281,7 +282,7 @@ All training data from one dataset are located under the `Dataset` folder, which The images themselves follow this naming convention: `{CASE_IDENTIFIER}_{CHANNEL}.{FILE_ENDING}`. In the above example, `LN003_0000` is the case identifier, `0000` is the 4-digit channel identifier, and `nii.gz` is the file ending. The segmentations follow a similar convention but without the channel indicator. This will be further explained below. -Luckily, assuming your data has already been converted to slices (hopefully using the `--use-file-prefix` flag), `train.py` will convert the data into the nnUNet format for you. However, you must first set the values of 2 dictionaries in the `.yaml` config file. An example is shown in [train_config.yaml](train_config.yaml): +Assuming your data has already been converted to slices (hopefully using the `--use-file-prefix` flag), `train.py` will convert the data into the nnUNet format for you. However, you must first set the values of 2 dictionaries in the `.yaml` config file. An example is shown in [train_config.yaml](configs/train_config.yaml): ``` # For nnUNet only: diff --git a/UltrasoundSegmentation/configs/UsTracking_train_config.yaml b/UltrasoundSegmentation/configs/UsTracking_train_config.yaml new file mode 100644 index 0000000..3c58dd6 --- /dev/null +++ b/UltrasoundSegmentation/configs/UsTracking_train_config.yaml @@ -0,0 +1,87 @@ +# Example config file for train.py + +network: custom # available networks: unet, attention_unet, effnetunet, unetplusplus, unetr, segresnet +loss_function: dicefocal +lambda_ce: 0.9 # 0.0 for dice, 1.0 for cross-entropy/focal +class_weights: [0.2, 0.8] # must be same length as out_channels +image_size: 128 +in_channels: &in_c !!int 5 +out_channels: !!int 2 +num_epochs: !!int 1 +batch_size: !!int 128 +learning_rate: !!float 0.001 +learning_rate_decay_factor: !!float 0.5 +learning_rate_decay_frequency: !!int 10 +weight_decay: 0.01 +shuffle: !!bool true # true/false +seed: !!int 42 + +dataset: + train_folder: /mnt/c/Users/chris/Data/Spine/2024_SpineSeg/04_Slices_train + val_folder: /mnt/c/Users/chris/Data/Spine/2024_SpineSeg/04_Slices_val + name: SlidingWindowTrackedUSDataset + params: + window_size: *in_c + gt_idx: 2 # 0: first, 1: middle, 2: last + +transforms: + general: + # Basic transforms, do not modify + - name: Transposed + params: + keys: [image, label] + indices: [2, 0, 1] + - name: ToTensord + - name: EnsureTyped + params: + keys: [image, label] + dtype: float32 + # Add additional transforms here + - name: Resized + params: + keys: [image, label] + spatial_size: [128, 128] + - name: ScaleIntensityRanged + params: + keys: [image] + a_min: 0.0 # minimum intensity in the original image + a_max: 255.0 # maximum intensity in the original image for 8-bit images + b_min: 0.0 # scaled minimum intensity + b_max: 1.0 # scaled maximum intensity + clip: true + train: + - name: RandGaussianNoised + params: + keys: [image] + prob: 0.5 + mean: 0.0 + std: 0.1 + - name: RandFlipd + params: + keys: [image, label] + prob: 0.5 + spatial_axis: [1] + - name: RandAdjustContrastd + params: + keys: [image] + prob: 0.5 + gamma: [0.5, 2] + - name: RandAffined + params: + keys: [image, label] + prob: 0.5 + spatial_size: [128, 128] + rotate_range: 0.5 + shear_range: [0.2, 0.2] + translate_range: [20, 20] + scale_range: [0.2, 0.2] + mode: bilinear + padding_mode: zeros + cache_grid: true + +# Custom model parameters +model: + model_path: /home/chrisyeung/spine-segmentation/train_dualnet.py + name: DualEncoderUNet + use_tracking: true + params: diff --git a/UltrasoundSegmentation/prepare_data_config.yaml b/UltrasoundSegmentation/configs/prepare_data_config.yaml similarity index 100% rename from UltrasoundSegmentation/prepare_data_config.yaml rename to UltrasoundSegmentation/configs/prepare_data_config.yaml diff --git a/UltrasoundSegmentation/configs/train_config.yaml b/UltrasoundSegmentation/configs/train_config.yaml new file mode 100644 index 0000000..fc034c1 --- /dev/null +++ b/UltrasoundSegmentation/configs/train_config.yaml @@ -0,0 +1,50 @@ +# Example config file for train.py + +network: unet # available networks: unet, attention_unet, effnetunet, unetplusplus, unetr, segresnet, custom +loss_function: DiceCE +lambda_ce: 0.5 # 0.0 for dice, 1.0 for cross-entropy/focal +class_weights: [0.25, 0.75] # must be same length as out_channels +image_size: 128 +in_channels: !!int 1 +out_channels: !!int 2 +num_epochs: !!int 100 +batch_size: !!int 64 +learning_rate: !!float 0.0001 +learning_rate_decay_factor: !!float 0.5 +learning_rate_decay_frequency: !!int 10 +weight_decay: 0.01 +shuffle: !!bool true # true/false +seed: !!int 42 + +dataset: + train_folder: /path/to/train + val_folder: /path/to/val + name: UltrasoundDataset + params: # dataset-specific parameters here + +transforms: + general: + # Basic transforms, do not modify + - name: Transposed + params: + keys: [image, label] + indices: [2, 0, 1] + - name: ToTensord + - name: EnsureTyped + params: + keys: [image, label] + dtype: float32 + + # Add additional transforms here + - name: Resized + params: + keys: [image, label] + spatial_size: [128, 128] + train: + +# Custom model parameters +model: + model_path: /path/to/model.py + name: CustomNet # name of the model class + use_tracking: false # use tracking data in model input + params: # model-specific parameters diff --git a/UltrasoundSegmentation/train_nnunet_config.yaml b/UltrasoundSegmentation/configs/train_nnunet_config.yaml similarity index 100% rename from UltrasoundSegmentation/train_nnunet_config.yaml rename to UltrasoundSegmentation/configs/train_nnunet_config.yaml diff --git a/UltrasoundSegmentation/datasets.py b/UltrasoundSegmentation/datasets.py index eb67b4f..87f74a9 100644 --- a/UltrasoundSegmentation/datasets.py +++ b/UltrasoundSegmentation/datasets.py @@ -80,6 +80,10 @@ def __getitem__(self, index): class SlidingWindowTrackedUSDataset(Dataset): + GT_CHANNEL_IDX_FIRST = 0 + GT_CHANNEL_IDX_MIDDLE = 1 + GT_CHANNEL_IDX_LAST = 2 + def __init__( self, root_folder, @@ -87,7 +91,8 @@ def __init__( gts_dir="labels", tfms_dir="transforms", transform=None, - window_size=4 + window_size=5, + gt_idx=GT_CHANNEL_IDX_LAST ): # get names of subfolders in imgs_dir, gts_dir, and tfms_dir image_scans = [ @@ -126,6 +131,16 @@ def __init__( self.transform = transform self.window_size = window_size + # which frame to use for ground truth + if gt_idx == self.GT_CHANNEL_IDX_FIRST: + self.gt_idx = 0 + elif gt_idx == self.GT_CHANNEL_IDX_MIDDLE: + self.gt_idx = window_size // 2 + elif gt_idx == self.GT_CHANNEL_IDX_LAST: + self.gt_idx = window_size - 1 + else: + raise ValueError("Invalid gt_idx value. Must be 0, 1, or 2.") + def __len__(self): return sum( len(self.data[scan]["image"]) - self.window_size + 1 @@ -146,8 +161,8 @@ def __getitem__(self, index): for i in range(self.window_size) ], axis=-1) # shape: (H, W, window_size) - # only take middle frame as label - label = np.load(self.data[scan]["label"][index + self.window_size // 2]) + # get gt image + label = np.load(self.data[scan]["label"][index + self.gt_idx]) # If segmentation_data is 2D, add a channel dimension as last dimension if len(label.shape) == 2: label = np.expand_dims(label, axis=-1) diff --git a/UltrasoundSegmentation/train.py b/UltrasoundSegmentation/train.py index 65a9227..a76a4d7 100644 --- a/UltrasoundSegmentation/train.py +++ b/UltrasoundSegmentation/train.py @@ -6,9 +6,7 @@ - Log training metrics to Weights & Biases """ -import matplotlib -matplotlib.use("Agg") # Use non-interactive backend to avoid error when running on server without GUI - +import importlib import argparse import logging import monai @@ -22,6 +20,8 @@ import wandb import numpy as np import matplotlib.pyplot as plt +import matplotlib +matplotlib.use("Agg") # Use non-interactive backend to avoid error when running on server without GUI from tqdm import tqdm from time import perf_counter @@ -44,17 +44,15 @@ ConfusionMatrixMetric ) +import datasets from lr_scheduler import PolyLRScheduler, LinearWarmupWrapper -from datasets import UltrasoundDataset, SlidingWindowTrackedUSDataset # Parse command line arguments def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--train-data-folder", type=str) - parser.add_argument("--val-data-folder", type=str) parser.add_argument("--output-dir", type=str) - parser.add_argument("--config-file", type=str, default="train_config.yaml") + parser.add_argument("--config-file", type=str, default="configs/train_config.yaml") parser.add_argument("--num-sample-images", type=int, default=3) parser.add_argument("--num-fps-test-images", type=int, default=100) parser.add_argument("--save-torchscript", action="store_true") @@ -90,7 +88,7 @@ def main(args): if args.wandb_exp_name is not None: experiment_name = f"{args.wandb_exp_name}_{timestamp}" else: - experiment_name = f"{config['model_name']}_{config['loss_function']}_{timestamp}" + experiment_name = f"{config['network']}_{config['loss_function']}_{timestamp}" run = wandb.init( # Set the project where this run will be logged project=args.wandb_project_name, @@ -182,15 +180,26 @@ def main(args): train_transform = Compose(train_transform_list) val_transform = Compose(val_transform_list) - train_dataset = UltrasoundDataset(args.train_data_folder, transform=train_transform) - val_dataset = UltrasoundDataset(args.val_data_folder, transform=val_transform) + # Create datasets + dataset_cfg = config["dataset"] + dataset_params = dataset_cfg.get("params", {}) + train_dataset = getattr(datasets, dataset_cfg["name"])( + dataset_cfg["train_folder"], + transform=train_transform, + **(dataset_params if dataset_params else {}) + ) + val_dataset = getattr(datasets, dataset_cfg["name"])( + dataset_cfg["val_folder"], + transform=val_transform, + **(dataset_params if dataset_params else {}) + ) # Create dataloaders using UltrasoundDataset train_dataloader = DataLoader( train_dataset, batch_size=config["batch_size"], shuffle=config["shuffle"], - num_workers=2, + num_workers=4, generator=g ) val_dataloader = DataLoader( @@ -226,12 +235,13 @@ def main(args): softmax=use_softmax, lambda_dice=(1.0 - config["lambda_ce"]), lambda_ce=config["lambda_ce"], - ce_weight=ce_weight + weight=ce_weight ) # Construct model + use_tracking = False dropout_rate = config["dropout_rate"] if "dropout_rate" in config else 0.0 - if config["model_name"].lower() == "attentionunet": + if config["network"].lower() == "attentionunet": model = monai.networks.nets.AttentionUnet( spatial_dims=2, in_channels=config["in_channels"], @@ -240,20 +250,20 @@ def main(args): strides=(2, 2, 2, 2), dropout=dropout_rate ) - elif config["model_name"].lower() == "effnetunet": + elif config["network"].lower() == "effnetunet": model = monai.networks.nets.FlexibleUNet( in_channels=config["in_channels"], out_channels=config["out_channels"], backbone="efficientnet-b4", pretrained=True ) - elif config["model_name"].lower() == "unetplusplus": + elif config["network"].lower() == "unetplusplus": model = monai.networks.nets.BasicUNetPlusPlus( spatial_dims=2, in_channels=config["in_channels"], out_channels=config["out_channels"] ) - elif config["model_name"].lower() == "unetr": + elif config["network"].lower() == "unetr": model = monai.networks.nets.UNETR( in_channels=config["in_channels"], out_channels=config["out_channels"], @@ -261,19 +271,42 @@ def main(args): dropout_rate=dropout_rate, spatial_dims=2 ) - elif config["model_name"].lower() == "swinunetr": + elif config["network"].lower() == "swinunetr": model = monai.networks.nets.SwinUNETR( img_size=config["image_size"], in_channels=config["in_channels"], out_channels=config["out_channels"], spatial_dims=2 ) - elif config["model_name"].lower() == "segresnet": + elif config["network"].lower() == "segresnet": model = monai.networks.nets.SegResNet( spatial_dims=2, in_channels=config["in_channels"], out_channels=config["out_channels"] ) + elif config["network"].lower() == "custom": + try: + model_cfg = config["model"] + model_path = model_cfg["model_path"] + except KeyError: + logging.error("Custom model not found in config file.") + return + + # load custom module + module_name = os.path.basename(model_path).split(".")[0] + loader = importlib.machinery.SourceFileLoader(module_name, model_path) + spec = importlib.util.spec_from_loader(loader.name, loader) + module = importlib.util.module_from_spec(spec) + loader.exec_module(module) + + # instantiate model + model_params = model_cfg.get("params", {}) + model = getattr(module, model_cfg["name"])( + in_channels=config["in_channels"], + out_channels=config["out_channels"], + **(model_params if model_params else {}) + ) + use_tracking = model_cfg["use_tracking"] else: # default to unet model = monai.networks.nets.UNet( spatial_dims=2, @@ -326,7 +359,7 @@ def main(args): # cosine annealing with warmup warmup_steps = config["warmup_steps"] if "warmup_steps" in config else 250 last_cosine_step = start_step - warmup_steps if start_step > warmup_steps else 0 - lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + lr_scheduler = CosineAnnealingLR( optimizer, max_steps - warmup_steps, last_epoch=last_cosine_step - 1 @@ -368,13 +401,17 @@ def main(args): labels = batch["label"].to(device=device) if config["out_channels"] > 1: labels = monai.networks.one_hot(labels, num_classes=config["out_channels"]) - outputs = model(inputs) + if use_tracking: + tfms = batch["transform"].to(device=device) + outputs = model(inputs, tfms) + else: + outputs = model(inputs) if isinstance(outputs, list): # for unet++ output outputs = outputs[0] loss = loss_function(outputs, labels) loss.backward() - scheduler.step() optimizer.step() + scheduler.step() optimizer.zero_grad() epoch_loss += loss.item() epoch_loss /= step @@ -391,7 +428,11 @@ def main(args): val_labels = val_batch["label"].to(device=device) if config["out_channels"] > 1: val_labels = monai.networks.one_hot(val_labels, num_classes=config["out_channels"]) - val_outputs = model(val_inputs) + if use_tracking: + val_tfms = val_batch["transform"].to(device=device) + val_outputs = model(val_inputs, val_tfms) + else: + val_outputs = model(val_inputs) if isinstance(val_outputs, list): val_outputs = val_outputs[0] loss = loss_function(val_outputs, val_labels) @@ -438,7 +479,11 @@ def main(args): inputs = torch.stack([val_dataset[i]["image"] for i in sample]) labels = torch.stack([val_dataset[i]["label"] for i in sample]) with torch.no_grad(): - outputs = model(inputs.to(device=device)) + if use_tracking: + tfms = torch.stack([torch.from_numpy(val_dataset[i]["transform"]) for i in sample]) + outputs = model(inputs.to(device=device), tfms.to(device=device)) + else: + outputs = model(inputs.to(device=device)) if isinstance(outputs, list): outputs = outputs[0] if isinstance(labels, list): @@ -508,9 +553,14 @@ def main(args): model.eval() # disable dropout and batchnorm ts_model_path = os.path.join(run_dir, "model_traced.pt") model = model.to("cpu") - example_input = torch.rand(1, config["in_channels"], config["image_size"], config["image_size"]) + if use_tracking: + example_input = (torch.rand(1, config["in_channels"], config["image_size"], config["image_size"]), + torch.rand(1, config["in_channels"], 4, 4)) + d = {"shape": example_input[0].shape, "use_tracking": True} + else: + example_input = torch.rand(1, config["in_channels"], config["image_size"], config["image_size"]) + d = {"shape": example_input.shape, "use_tracking": False} traced_script_module = torch.jit.trace(model, example_input) - d = {"shape": example_input.shape} extra_files = {"config.json": json.dumps(d)} traced_script_module.save(ts_model_path, _extra_files=extra_files) @@ -530,12 +580,19 @@ def main(args): logging.info("Measuring inference time...") num_test_images = args.num_fps_test_images inputs = torch.stack([val_dataset[i]["image"] for i in range(num_test_images)]) + tfms = torch.stack([torch.from_numpy(val_dataset[i]["transform"]) for i in range(num_test_images)]) model.to(device) model.eval() with torch.no_grad(): start = perf_counter() for i in range(num_test_images): - model(inputs[i, :, :, :].unsqueeze(0).to(device=device)) + if use_tracking: + model( + inputs[i, :, :, :].unsqueeze(0).to(device=device), + tfms[i, :, :, :].unsqueeze(0).to(device=device) + ) + else: + model(inputs[i, :, :, :].unsqueeze(0).to(device=device)) end = perf_counter() avg_inf_time = (end - start) / num_test_images avg_inf_fps = 1 / avg_inf_time diff --git a/UltrasoundSegmentation/train_config.yaml b/UltrasoundSegmentation/train_config.yaml deleted file mode 100644 index d608af4..0000000 --- a/UltrasoundSegmentation/train_config.yaml +++ /dev/null @@ -1,36 +0,0 @@ -# Example config file for train.py - -model_name: "unet" # available networks: unet, attention_unet, effnetunet, unetplusplus, unetr, segresnet -loss_function: "DiceCE" -lambda_ce: 0.5 # 0.0 for dice, 1.0 for cross-entropy/focal -class_weights: [0.25, 0.75] # must be same length as out_channels -image_size: 128 -in_channels: !!int 1 -out_channels: !!int 2 -num_epochs: !!int 100 -batch_size: !!int 64 -learning_rate: !!float 0.0001 -learning_rate_decay_factor: !!float 0.5 -learning_rate_decay_frequency: !!int 10 -weight_decay: 0.01 -shuffle: !!bool true # true/false -seed: !!int 42 -transforms: - general: - # Basic transforms, do not modify - - name: "Transposed" - params: - keys: ["image", "label"] - indices: [2, 0, 1] - - name: "ToTensord" - - name: "EnsureTyped" - params: - keys: ["image", "label"] - dtype: "float32" - - # Add additional transforms here - - name: "Resized" - params: - keys: ["image", "label"] - spatial_size: [128, 128] - train: diff --git a/UltrasoundSegmentation/train_nnunet.py b/UltrasoundSegmentation/train_nnunet.py index e3e5026..1a9f27b 100644 --- a/UltrasoundSegmentation/train_nnunet.py +++ b/UltrasoundSegmentation/train_nnunet.py @@ -48,7 +48,7 @@ def parse_args(): parser.add_argument("--train-data-folder", type=str) parser.add_argument("--val-data-folder", type=str) parser.add_argument("--output-dir", type=str) - parser.add_argument("--config-file", type=str, default="train_config.yaml") + parser.add_argument("--config-file", type=str, default="configs/train_nnunet_config.yaml") parser.add_argument("--num-sample-images", type=int, default=3) parser.add_argument("--num-fps-test-images", type=int, default=100) parser.add_argument("--save-torchscript", action="store_true")