Added option to resume training. Also now saves the model after each epoch. Modified UltrasoundDataset to be more flexible to other folder structures. Added several optimizer/scheduler options.
chriscyyeung committed Feb 28, 2024
1 parent d5e714f commit 9a87285
Showing 4 changed files with 164 additions and 66 deletions.
8 changes: 4 additions & 4 deletions UltrasoundSegmentation/UltrasoundDataset.py
@@ -11,13 +11,13 @@ class UltrasoundDataset(Dataset):
Dataset class for ultrasound images, segmentations, and transformations.
"""

def __init__(self, data_folder, transform=None):
def __init__(self, root_folder, imgs_dir="images", gts_dir="labels", tfms_dir="transforms", transform=None):
self.transform = transform

# Find all data segmentation files and matching ultrasound files in input directory
self.images = sorted(glob.glob(os.path.join(data_folder, "**", "*_ultrasound*.npy"), recursive=True))
self.segmentations = sorted(glob.glob(os.path.join(data_folder, "**", "*_segmentation*.npy"), recursive=True))
self.tfm_matrices = sorted(glob.glob(os.path.join(data_folder, "**", "*_transform*.npy"), recursive=True))
self.images = glob.glob(os.path.join(root_folder, "**", imgs_dir, "**", "*.npy"), recursive=True)
self.segmentations = glob.glob(os.path.join(root_folder, "**", gts_dir, "**", "*.npy"), recursive=True)
self.tfm_matrices = glob.glob(os.path.join(root_folder, "**", tfms_dir, "**", "*.npy"), recursive=True)

def __len__(self):
"""
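A minimal sketch of how the reworked dataset could be pointed at a different folder structure. The root_folder/imgs_dir/gts_dir/tfms_dir arguments come from the diff above; the directory tree and file names are illustrative only. Note that the new glob calls are no longer sorted, so the ordering of images, labels, and transforms depends on the filesystem; callers that rely on index-aligned lists may want to sort the results.

# Hypothetical layout (illustrative, not from the repository):
#   data/train/patient_01/images/0001.npy
#   data/train/patient_01/labels/0001.npy
#   data/train/patient_01/transforms/0001.npy
from UltrasoundDataset import UltrasoundDataset

train_ds = UltrasoundDataset(
    "data/train",           # root_folder, searched recursively
    imgs_dir="images",      # any folder named "images" below the root matches
    gts_dir="labels",
    tfms_dir="transforms",
)
print(len(train_ds))        # number of ultrasound frames found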
2 changes: 1 addition & 1 deletion UltrasoundSegmentation/environment.yml
@@ -13,4 +13,4 @@ dependencies:
- pip:
- requests
- opencv-python
- monai[scikit-image, tqdm, pyyaml, matplotlib, einops]
- monai[skimage, scipy, tqdm, pyyaml, matplotlib, einops]
38 changes: 38 additions & 0 deletions UltrasoundSegmentation/lr_scheduler.py
@@ -0,0 +1,38 @@
from torch.optim.lr_scheduler import _LRScheduler


class PolyLRScheduler(_LRScheduler):
"""Adapted from https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/training/lr_scheduler/polylr.py"""
def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, last_step: int = -1):
self.optimizer = optimizer
self.initial_lr = initial_lr
self.max_steps = max_steps
self.exponent = exponent
self.last_step = last_step
super().__init__(optimizer, last_step, False)

def step(self, epoch=None):
self.last_step += 1
new_lr = self.initial_lr * (1 - self.last_step / self.max_steps) ** self.exponent
for param_group in self.optimizer.param_groups:
param_group['lr'] = new_lr


class LinearWarmupWrapper(_LRScheduler):
"""Wrapper for a PyTorch scheduler to add a linear LR warmup."""
def __init__(self, optimizer, scheduler, initial_lr, warmup_steps, last_step=-1):
self.optimizer = optimizer
self.scheduler = scheduler
self.initial_lr = initial_lr
self.warmup_steps = warmup_steps
self.last_step = last_step
super().__init__(optimizer, last_step, False)

def step(self, epoch=None):
self.last_step += 1
if self.last_step <= self.warmup_steps:
warmup_factor = min(1.0, (self.last_step + 1) / self.warmup_steps)
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.initial_lr * warmup_factor
else:
self.scheduler.step()
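A minimal usage sketch of the new schedulers, mirroring how train.py composes LinearWarmupWrapper with CosineAnnealingLR further down. The model, learning rate, and step counts are placeholders; the wrapper is stepped once per minibatch rather than per epoch, as noted in the train.py comments below.

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

from lr_scheduler import LinearWarmupWrapper

model = torch.nn.Linear(16, 2)  # placeholder model
base_lr, warmup_steps, max_steps = 1e-3, 250, 10_000

optimizer = AdamW(model.parameters(), base_lr)
cosine = CosineAnnealingLR(optimizer, max_steps - warmup_steps)
scheduler = LinearWarmupWrapper(optimizer, cosine, base_lr, warmup_steps=warmup_steps)

for step in range(max_steps):
    # ... forward pass, loss.backward() ...
    optimizer.step()
    scheduler.step()   # linear ramp for the first warmup_steps, cosine decay afterwards
    optimizer.zero_grad()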
182 changes: 121 additions & 61 deletions UltrasoundSegmentation/train.py
@@ -26,8 +26,8 @@
from tqdm import tqdm
from time import perf_counter
from datetime import datetime
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR

from monai.data import DataLoader
from monai.data.utils import decollate_batch
@@ -39,10 +39,12 @@
)
from monai.metrics import (
DiceMetric,
MeanIoU,
HausdorffDistanceMetric,
ConfusionMatrixMetric
)

from lr_scheduler import PolyLRScheduler, LinearWarmupWrapper
from UltrasoundDataset import UltrasoundDataset


@@ -62,6 +64,7 @@ def parse_args():
parser.add_argument("--wandb-exp-name", type=str)
parser.add_argument("--log-level", type=str, default="INFO")
parser.add_argument("--save-log", action="store_true")
parser.add_argument("--resume-ckpt", type=str)
try:
return parser.parse_args()
except SystemExit as err:
@@ -198,10 +201,36 @@ def main(args):
generator=g
)

# Construct model
# Construct loss function
use_sigmoid = True if config["out_channels"] == 1 else False
use_softmax = True if config["out_channels"] > 1 else False
ce_weight = torch.tensor(config["class_weights"], device=device) \
if config["out_channels"] > 1 else None
if config["loss_function"].lower() == "dicefocal":
loss_function = monai.losses.DiceFocalLoss(
sigmoid=use_sigmoid,
softmax=use_softmax,
lambda_dice=(1.0 - config["lambda_ce"]),
lambda_focal=config["lambda_ce"]
)
elif config["loss_function"].lower() == "tversky":
loss_function = monai.losses.TverskyLoss(
sigmoid=use_sigmoid,
softmax=use_softmax,
alpha=0.3,
beta=0.7 # best values from original paper
)
else: # default to dice + cross entropy
loss_function = monai.losses.DiceCELoss(
sigmoid=use_sigmoid,
softmax=use_softmax,
lambda_dice=(1.0 - config["lambda_ce"]),
lambda_ce=config["lambda_ce"],
ce_weight=ce_weight
)

# Construct model
dropout_rate = config["dropout_rate"] if "dropout_rate" in config else 0.0

if config["model_name"].lower() == "attentionunet":
model = monai.networks.nets.AttentionUnet(
spatial_dims=2,
@@ -255,55 +284,66 @@ def main(args):
num_res_units=config["num_res_units"] if "num_res_units" in config else 2,
dropout=dropout_rate
)
model = model.to(device=device)

# Construct loss function
use_sigmoid = True if config["out_channels"] == 1 else False
use_softmax = True if config["out_channels"] > 1 else False
ce_weight = torch.tensor(config["class_weights"], device=device) \
if config["out_channels"] > 1 else None
if config["loss_function"].lower() == "dicefocal":
loss_function = monai.losses.DiceFocalLoss(
sigmoid=use_sigmoid,
softmax=use_softmax,
lambda_dice=(1.0 - config["lambda_ce"]),
lambda_focal=config["lambda_ce"]
)
elif config["loss_function"].lower() == "tversky":
loss_function = monai.losses.TverskyLoss(
sigmoid=use_sigmoid,
softmax=use_softmax,
alpha=0.3,
beta=0.7 # best values from original paper
)
else: # default to dice + cross entropy
loss_function = monai.losses.DiceCELoss(
sigmoid=use_sigmoid,
softmax=use_softmax,
lambda_dice=(1.0 - config["lambda_ce"]),
lambda_ce=config["lambda_ce"],
ce_weight=ce_weight
)
model = model.to(device=device)
# optimizer = Adam(model.parameters(), config["learning_rate"], weight_decay=config["weight_decay"])
optimizer = AdamW(model.parameters(), config["learning_rate"], weight_decay=config["weight_decay"])

optimizer = Adam(model.parameters(), config["learning_rate"], weight_decay=config["weight_decay"])
# resume training
if args.resume_ckpt:
try:
state = torch.load(args.resume_ckpt)
model.load_state_dict(state["model"])
optimizer.load_state_dict(state["optimizer"])
start_epoch = state["epoch"] + 1
best_val_loss = state["val_loss"]
logging.info(f"Loaded model from {args.resume_ckpt}.")
except Exception as e:
logging.error(f"Failed to load model from {args.resume_ckpt}: {e}.")
else:
start_epoch = 0
best_val_loss = np.inf

# Set up learning rate decay
# Set up learning rate scheduler
try:
learning_rate_decay_frequency = int(config["learning_rate_decay_frequency"])
except ValueError:
except Exception:
learning_rate_decay_frequency = 100
try:
learning_rate_decay_factor = float(config["learning_rate_decay_factor"])
except ValueError:
except Exception:
learning_rate_decay_factor = 1.0 # No decay
logging.info(f"Learning rate decay frequency: {learning_rate_decay_frequency}")
logging.info(f"Learning rate decay factor: {learning_rate_decay_factor}")
scheduler = StepLR(optimizer, step_size=learning_rate_decay_frequency, gamma=learning_rate_decay_factor)
# logging.info(f"Learning rate decay frequency: {learning_rate_decay_frequency}")
# logging.info(f"Learning rate decay factor: {learning_rate_decay_factor}")
# scheduler = StepLR(optimizer, step_size=learning_rate_decay_frequency, gamma=learning_rate_decay_factor)

# next two schedulers use minibatch as step, not epoch
# need to move scheduler.step() to inner loop
start_step = start_epoch * len(train_dataloader)
max_steps = len(train_dataloader) * config["num_epochs"]
# scheduler = PolyLRScheduler(optimizer, config["learning_rate"], max_steps, last_step=start_step - 1)

# cosine annealing with warmup
warmup_steps = config["warmup_steps"] if "warmup_steps" in config else 250
last_cosine_step = start_step - warmup_steps if start_step > warmup_steps else 0
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
max_steps - warmup_steps,
last_epoch=last_cosine_step - 1
)
scheduler = LinearWarmupWrapper(
optimizer,
lr_scheduler,
config["learning_rate"],
warmup_steps=warmup_steps,
last_step=start_step - 1
)

# Metrics
include_background = True if config["out_channels"] == 1 else False
dice_metric = DiceMetric(include_background=include_background, reduction="mean")
iou_metric = MeanIoU(include_background=include_background, reduction="mean")
hd95_metric = HausdorffDistanceMetric(include_background=include_background, percentile=95.0, reduction="mean")
confusion_matrix_metric = ConfusionMatrixMetric(
include_background=include_background,
metric_name=["accuracy", "precision", "sensitivity", "specificity", "f1_score"],
@@ -317,8 +357,8 @@ def main(args):

# Train model
epochs = config["num_epochs"]
for epoch in range(epochs):
logging.info(f"Epoch {epoch+1}/{epochs}")
for epoch in range(start_epoch, epochs):
logging.info(f"Epoch {epoch + 1}/{epochs}")
model.train()
epoch_loss = 0
step = 0
@@ -328,18 +368,17 @@ def main(args):
labels = batch["label"].to(device=device)
if config["out_channels"] > 1:
labels = monai.networks.one_hot(labels, num_classes=config["out_channels"])
optimizer.zero_grad()
outputs = model(inputs)
if isinstance(outputs, list): # for unet++ output
outputs = outputs[0]
loss = loss_function(outputs, labels)
loss.backward()
scheduler.step()
optimizer.step()
optimizer.zero_grad()
epoch_loss += loss.item()
epoch_loss /= step
logging.info(f"Training loss: {epoch_loss}")
if config["model_name"].lower() != "nnunet":
scheduler.step()

# Validation step
model.eval()
@@ -364,23 +403,27 @@ def main(args):

dice_metric(y_pred=val_outputs, y=val_labels)
iou_metric(y_pred=val_outputs, y=val_labels)
hd95_metric(y_pred=val_outputs, y=val_labels)
confusion_matrix_metric(y_pred=val_outputs, y=val_labels)

val_loss /= val_step
dice = dice_metric.aggregate().item()
iou = iou_metric.aggregate().item()
hd95 = hd95_metric.aggregate().item()
cm = confusion_matrix_metric.aggregate()

# reset status for next validation round
dice_metric.reset()
iou_metric.reset()
hd95_metric.reset()
confusion_matrix_metric.reset()

logging.info(
f"Validation results:\n"
f"\tLoss: {val_loss}\n"
f"\tDice: {dice}\n"
f"\tIoU: {iou}\n"
f"\t95% HD: {hd95}\n"
f"\tAccuracy: {(acc := cm[0].item())}\n"
f"\tPrecision: {(pre := cm[1].item())}\n"
f"\tSensitivity: {(sen := cm[2].item())}\n"
@@ -429,6 +472,7 @@ def main(args):
"val_loss": val_loss,
"dice": dice,
"iou": iou,
"95hd": hd95,
"accuracy": acc,
"precision": pre,
"sensitivity": sen,
@@ -449,22 +493,38 @@ def main(args):
torch.save(model.state_dict(), ckpt_model_path)
logging.info(f"Saved model checkpoint to {ckpt_model_path}.")

# Save the final model also under the name "model.pt" so that we can easily find it later.
# This is useful if we want to use the model for inference without having to specify the model filename.
model_path = os.path.join(run_dir, "model.pt")
torch.save(model.state_dict(), model_path)
logging.info(f"Saved model to {model_path}.")

# Save model as TorchScript
if args.save_torchscript:
ts_model_path = os.path.join(run_dir, "model_traced.pt")
model = model.to("cpu")
example_input = torch.rand(1, config["in_channels"], config["image_size"], config["image_size"])
traced_script_module = torch.jit.trace(model, example_input)
d = {"shape": example_input.shape}
extra_files = {"config.json": json.dumps(d)}
traced_script_module.save(ts_model_path, _extra_files=extra_files)
logging.info(f"Saved traced model to {ts_model_path}.")
# Keep latest model as just model.pt, which can be used for inference or to resume training
model_path = os.path.join(run_dir, "model.pt")
state = {
"epoch": epoch,
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"val_loss": val_loss
}
torch.save(state, model_path)

# Save latest model as TorchScript
if args.save_torchscript:
model.eval() # disable dropout and batchnorm
ts_model_path = os.path.join(run_dir, "model_traced.pt")
model = model.to("cpu")
example_input = torch.rand(1, config["in_channels"], config["image_size"], config["image_size"])
traced_script_module = torch.jit.trace(model, example_input)
d = {"shape": example_input.shape}
extra_files = {"config.json": json.dumps(d)}
traced_script_module.save(ts_model_path, _extra_files=extra_files)

# Save best model
if val_loss < best_val_loss:
best_val_loss = val_loss
best_model_path = os.path.join(run_dir, "model_best.pt")
torch.save(state, best_model_path)

if args.save_torchscript:
best_ts_model_path = os.path.join(run_dir, "model_traced_best.pt")
traced_script_module.save(best_ts_model_path, _extra_files=extra_files)

model = model.to(device=device) # return model to original device

# Test inference time (load images before loop to exclude from time measurement)
logging.info("Measuring inference time...")
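Because model.pt now stores a dict rather than a bare state_dict, downstream code has to unpack it before use. A minimal sketch under that assumption; the architecture, optimizer, and path below are placeholders, and the same file is what the new --resume-ckpt argument expects.

import torch

# Placeholders; in practice these match the architecture and optimizer built in train.py
model = torch.nn.Linear(16, 2)
optimizer = torch.optim.AdamW(model.parameters(), 1e-3)

ckpt = torch.load("model.pt", map_location="cpu")

# Resuming training uses all of the keys written by train.py above
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
start_epoch = ckpt["epoch"] + 1
best_val_loss = ckpt["val_loss"]

# Plain inference only needs the weights
model.load_state_dict(ckpt["model"])
model.eval()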
