From 9a8728589ae1f6ead104210ed5ced5d15f5f4ef0 Mon Sep 17 00:00:00 2001 From: chriscyyeung Date: Wed, 28 Feb 2024 14:51:59 -0500 Subject: [PATCH] Added option to resume training. Also now saves the model after each epoch. Modified UltrasoundDataset to be more flexible to other folder structures. Added several optimizer/scheduler options. --- UltrasoundSegmentation/UltrasoundDataset.py | 8 +- UltrasoundSegmentation/environment.yml | 2 +- UltrasoundSegmentation/lr_scheduler.py | 38 ++++ UltrasoundSegmentation/train.py | 182 +++++++++++++------- 4 files changed, 164 insertions(+), 66 deletions(-) create mode 100644 UltrasoundSegmentation/lr_scheduler.py diff --git a/UltrasoundSegmentation/UltrasoundDataset.py b/UltrasoundSegmentation/UltrasoundDataset.py index ab9426d..375dcb2 100644 --- a/UltrasoundSegmentation/UltrasoundDataset.py +++ b/UltrasoundSegmentation/UltrasoundDataset.py @@ -11,13 +11,13 @@ class UltrasoundDataset(Dataset): Dataset class for ultrasound images, segmentations, and transformations. """ - def __init__(self, data_folder, transform=None): + def __init__(self, root_folder, imgs_dir="images", gts_dir="labels", tfms_dir="transforms", transform=None): self.transform = transform # Find all data segmentation files and matching ultrasound files in input directory - self.images = sorted(glob.glob(os.path.join(data_folder, "**", "*_ultrasound*.npy"), recursive=True)) - self.segmentations = sorted(glob.glob(os.path.join(data_folder, "**", "*_segmentation*.npy"), recursive=True)) - self.tfm_matrices = sorted(glob.glob(os.path.join(data_folder, "**", "*_transform*.npy"), recursive=True)) + self.images = glob.glob(os.path.join(root_folder, "**", imgs_dir, "**", "*.npy"), recursive=True) + self.segmentations = glob.glob(os.path.join(root_folder, "**", gts_dir, "**", "*.npy"), recursive=True) + self.tfm_matrices = glob.glob(os.path.join(root_folder, "**", tfms_dir, "**", "*.npy"), recursive=True) def __len__(self): """ diff --git a/UltrasoundSegmentation/environment.yml b/UltrasoundSegmentation/environment.yml index e54b477..08b9406 100644 --- a/UltrasoundSegmentation/environment.yml +++ b/UltrasoundSegmentation/environment.yml @@ -13,4 +13,4 @@ dependencies: - pip: - requests - opencv-python - - monai[scikit-image, tqdm, pyyaml, matplotlib, einops] \ No newline at end of file + - monai[skimage, scipy, tqdm, pyyaml, matplotlib, einops] \ No newline at end of file diff --git a/UltrasoundSegmentation/lr_scheduler.py b/UltrasoundSegmentation/lr_scheduler.py new file mode 100644 index 0000000..4fe95d1 --- /dev/null +++ b/UltrasoundSegmentation/lr_scheduler.py @@ -0,0 +1,38 @@ +from torch.optim.lr_scheduler import _LRScheduler + + +class PolyLRScheduler(_LRScheduler): + """Adapted from https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/training/lr_scheduler/polylr.py""" + def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, last_step: int = -1): + self.optimizer = optimizer + self.initial_lr = initial_lr + self.max_steps = max_steps + self.exponent = exponent + self.last_step = last_step + super().__init__(optimizer, last_step, False) + + def step(self, epoch=None): + self.last_step += 1 + new_lr = self.initial_lr * (1 - self.last_step / self.max_steps) ** self.exponent + for param_group in self.optimizer.param_groups: + param_group['lr'] = new_lr + + +class LinearWarmupWrapper(_LRScheduler): + """Wrapper for a PyTorch scheduler to add a linear LR warmup.""" + def __init__(self, optimizer, scheduler, initial_lr, warmup_steps, last_step=-1): + self.optimizer = optimizer + self.scheduler = scheduler + self.initial_lr = initial_lr + self.warmup_steps = warmup_steps + self.last_step = last_step + super().__init__(optimizer, last_step, False) + + def step(self, epoch=None): + self.last_step += 1 + if self.last_step <= self.warmup_steps: + warmup_factor = min(1.0, (self.last_step + 1) / self.warmup_steps) + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.initial_lr * warmup_factor + else: + self.scheduler.step() diff --git a/UltrasoundSegmentation/train.py b/UltrasoundSegmentation/train.py index 3849306..8366e01 100644 --- a/UltrasoundSegmentation/train.py +++ b/UltrasoundSegmentation/train.py @@ -26,8 +26,8 @@ from tqdm import tqdm from time import perf_counter from datetime import datetime -from torch.optim import Adam -from torch.optim.lr_scheduler import StepLR +from torch.optim import Adam, AdamW +from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR from monai.data import DataLoader from monai.data.utils import decollate_batch @@ -39,10 +39,12 @@ ) from monai.metrics import ( DiceMetric, - MeanIoU, + MeanIoU, + HausdorffDistanceMetric, ConfusionMatrixMetric ) +from lr_scheduler import PolyLRScheduler, LinearWarmupWrapper from UltrasoundDataset import UltrasoundDataset @@ -62,6 +64,7 @@ def parse_args(): parser.add_argument("--wandb-exp-name", type=str) parser.add_argument("--log-level", type=str, default="INFO") parser.add_argument("--save-log", action="store_true") + parser.add_argument("--resume-ckpt", type=str) try: return parser.parse_args() except SystemExit as err: @@ -198,10 +201,36 @@ def main(args): generator=g ) - # Construct model + # Construct loss function + use_sigmoid = True if config["out_channels"] == 1 else False + use_softmax = True if config["out_channels"] > 1 else False + ce_weight = torch.tensor(config["class_weights"], device=device) \ + if config["out_channels"] > 1 else None + if config["loss_function"].lower() == "dicefocal": + loss_function = monai.losses.DiceFocalLoss( + sigmoid=use_sigmoid, + softmax=use_softmax, + lambda_dice=(1.0 - config["lambda_ce"]), + lambda_focal=config["lambda_ce"] + ) + elif config["loss_function"].lower() == "tversky": + loss_function = monai.losses.TverskyLoss( + sigmoid=use_sigmoid, + softmax=use_softmax, + alpha=0.3, + beta=0.7 # best values from original paper + ) + else: # default to dice + cross entropy + loss_function = monai.losses.DiceCELoss( + sigmoid=use_sigmoid, + softmax=use_softmax, + lambda_dice=(1.0 - config["lambda_ce"]), + lambda_ce=config["lambda_ce"], + ce_weight=ce_weight + ) + # Construct model dropout_rate = config["dropout_rate"] if "dropout_rate" in config else 0.0 - if config["model_name"].lower() == "attentionunet": model = monai.networks.nets.AttentionUnet( spatial_dims=2, @@ -255,55 +284,66 @@ def main(args): num_res_units=config["num_res_units"] if "num_res_units" in config else 2, dropout=dropout_rate ) - model = model.to(device=device) - # Construct loss function - use_sigmoid = True if config["out_channels"] == 1 else False - use_softmax = True if config["out_channels"] > 1 else False - ce_weight = torch.tensor(config["class_weights"], device=device) \ - if config["out_channels"] > 1 else None - if config["loss_function"].lower() == "dicefocal": - loss_function = monai.losses.DiceFocalLoss( - sigmoid=use_sigmoid, - softmax=use_softmax, - lambda_dice=(1.0 - config["lambda_ce"]), - lambda_focal=config["lambda_ce"] - ) - elif config["loss_function"].lower() == "tversky": - loss_function = monai.losses.TverskyLoss( - sigmoid=use_sigmoid, - softmax=use_softmax, - alpha=0.3, - beta=0.7 # best values from original paper - ) - else: # default to dice + cross entropy - loss_function = monai.losses.DiceCELoss( - sigmoid=use_sigmoid, - softmax=use_softmax, - lambda_dice=(1.0 - config["lambda_ce"]), - lambda_ce=config["lambda_ce"], - ce_weight=ce_weight - ) + model = model.to(device=device) + # optimizer = Adam(model.parameters(), config["learning_rate"], weight_decay=config["weight_decay"]) + optimizer = AdamW(model.parameters(), config["learning_rate"], weight_decay=config["weight_decay"]) - optimizer = Adam(model.parameters(), config["learning_rate"], weight_decay=config["weight_decay"]) + # resume training + if args.resume_ckpt: + try: + state = torch.load(args.resume_ckpt) + model.load_state_dict(state["model"]) + optimizer.load_state_dict(state["optimizer"]) + start_epoch = state["epoch"] + 1 + best_val_loss = state["val_loss"] + logging.info(f"Loaded model from {args.resume_ckpt}.") + except Exception as e: + logging.error(f"Failed to load model from {args.resume_ckpt}: {e}.") + else: + start_epoch = 0 + best_val_loss = np.inf - # Set up learning rate decay + # Set up learning rate scheduler try: learning_rate_decay_frequency = int(config["learning_rate_decay_frequency"]) - except ValueError: + except Exception: learning_rate_decay_frequency = 100 try: learning_rate_decay_factor = float(config["learning_rate_decay_factor"]) - except ValueError: + except Exception: learning_rate_decay_factor = 1.0 # No decay - logging.info(f"Learning rate decay frequency: {learning_rate_decay_frequency}") - logging.info(f"Learning rate decay factor: {learning_rate_decay_factor}") - scheduler = StepLR(optimizer, step_size=learning_rate_decay_frequency, gamma=learning_rate_decay_factor) + # logging.info(f"Learning rate decay frequency: {learning_rate_decay_frequency}") + # logging.info(f"Learning rate decay factor: {learning_rate_decay_factor}") + # scheduler = StepLR(optimizer, step_size=learning_rate_decay_frequency, gamma=learning_rate_decay_factor) + + # next two schedulers use minibatch as step, not epoch + # need to move scheduler.step() to inner loop + start_step = start_epoch * len(train_dataloader) + max_steps = len(train_dataloader) * config["num_epochs"] + # scheduler = PolyLRScheduler(optimizer, config["learning_rate"], max_steps, last_step=start_step - 1) + + # cosine annealing with warmup + warmup_steps = config["warmup_steps"] if "warmup_steps" in config else 250 + last_cosine_step = start_step - warmup_steps if start_step > warmup_steps else 0 + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + max_steps - warmup_steps, + last_epoch=last_cosine_step - 1 + ) + scheduler = LinearWarmupWrapper( + optimizer, + lr_scheduler, + config["learning_rate"], + warmup_steps=warmup_steps, + last_step=start_step - 1 + ) # Metrics include_background = True if config["out_channels"] == 1 else False dice_metric = DiceMetric(include_background=include_background, reduction="mean") iou_metric = MeanIoU(include_background=include_background, reduction="mean") + hd95_metric = HausdorffDistanceMetric(include_background=include_background, percentile=95.0, reduction="mean") confusion_matrix_metric = ConfusionMatrixMetric( include_background=include_background, metric_name=["accuracy", "precision", "sensitivity", "specificity", "f1_score"], @@ -317,8 +357,8 @@ def main(args): # Train model epochs = config["num_epochs"] - for epoch in range(epochs): - logging.info(f"Epoch {epoch+1}/{epochs}") + for epoch in range(start_epoch, epochs): + logging.info(f"Epoch {epoch + 1}/{epochs}") model.train() epoch_loss = 0 step = 0 @@ -328,18 +368,17 @@ def main(args): labels = batch["label"].to(device=device) if config["out_channels"] > 1: labels = monai.networks.one_hot(labels, num_classes=config["out_channels"]) - optimizer.zero_grad() outputs = model(inputs) if isinstance(outputs, list): # for unet++ output outputs = outputs[0] loss = loss_function(outputs, labels) loss.backward() + scheduler.step() optimizer.step() + optimizer.zero_grad() epoch_loss += loss.item() epoch_loss /= step logging.info(f"Training loss: {epoch_loss}") - if config["model_name"].lower() != "nnunet": - scheduler.step() # Validation step model.eval() @@ -364,16 +403,19 @@ def main(args): dice_metric(y_pred=val_outputs, y=val_labels) iou_metric(y_pred=val_outputs, y=val_labels) + hd95_metric(y_pred=val_outputs, y=val_labels) confusion_matrix_metric(y_pred=val_outputs, y=val_labels) val_loss /= val_step dice = dice_metric.aggregate().item() iou = iou_metric.aggregate().item() + hd95 = hd95_metric.aggregate().item() cm = confusion_matrix_metric.aggregate() # reset status for next validation round dice_metric.reset() iou_metric.reset() + hd95_metric.reset() confusion_matrix_metric.reset() logging.info( @@ -381,6 +423,7 @@ def main(args): f"\tLoss: {val_loss}\n" f"\tDice: {dice}\n" f"\tIoU: {iou}\n" + f"\t95% HD: {hd95}\n" f"\tAccuracy: {(acc := cm[0].item())}\n" f"\tPrecision: {(pre := cm[1].item())}\n" f"\tSensitivity: {(sen := cm[2].item())}\n" @@ -429,6 +472,7 @@ def main(args): "val_loss": val_loss, "dice": dice, "iou": iou, + "95hd": hd95, "accuracy": acc, "precision": pre, "sensitivity": sen, @@ -449,22 +493,38 @@ def main(args): torch.save(model.state_dict(), ckpt_model_path) logging.info(f"Saved model checkpoint to {ckpt_model_path}.") - # Save the final model also under the name "model.pt" so that we can easily find it later. - # This is useful if we want to use the model for inference without having to specify the model filename. - model_path = os.path.join(run_dir, "model.pt") - torch.save(model.state_dict(), model_path) - logging.info(f"Saved model to {model_path}.") - - # Save model as TorchScript - if args.save_torchscript: - ts_model_path = os.path.join(run_dir, "model_traced.pt") - model = model.to("cpu") - example_input = torch.rand(1, config["in_channels"], config["image_size"], config["image_size"]) - traced_script_module = torch.jit.trace(model, example_input) - d = {"shape": example_input.shape} - extra_files = {"config.json": json.dumps(d)} - traced_script_module.save(ts_model_path, _extra_files=extra_files) - logging.info(f"Saved traced model to {ts_model_path}.") + # Keep latest model as just model.pt, which can be used for inference or to resume training + model_path = os.path.join(run_dir, "model.pt") + state = { + "epoch": epoch, + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "val_loss": val_loss + } + torch.save(state, model_path) + + # Save latest model as TorchScript + if args.save_torchscript: + model.eval() # disable dropout and batchnorm + ts_model_path = os.path.join(run_dir, "model_traced.pt") + model = model.to("cpu") + example_input = torch.rand(1, config["in_channels"], config["image_size"], config["image_size"]) + traced_script_module = torch.jit.trace(model, example_input) + d = {"shape": example_input.shape} + extra_files = {"config.json": json.dumps(d)} + traced_script_module.save(ts_model_path, _extra_files=extra_files) + + # Save best model + if val_loss < best_val_loss: + best_val_loss = val_loss + best_model_path = os.path.join(run_dir, "model_best.pt") + torch.save(state, best_model_path) + + if args.save_torchscript: + best_ts_model_path = os.path.join(run_dir, "model_traced_best.pt") + traced_script_module.save(best_ts_model_path, _extra_files=extra_files) + + model = model.to(device=device) # return model to original device # Test inference time (load images before loop to exclude from time measurement) logging.info("Measuring inference time...")