From 491384f8e6546f453a8bc2784f41682fea8157ae Mon Sep 17 00:00:00 2001 From: chriscyyeung Date: Fri, 14 Jun 2024 11:55:24 -0400 Subject: [PATCH] Updated TorchSequenceSegmentation to include all models in folder. Added edge erosion to curvilinear mask. --- .../Resources/UI/TorchSequenceSegmentation.ui | 68 +++++++++++++++---- .../TorchSequenceSegmentation.py | 33 +++++++-- 2 files changed, 83 insertions(+), 18 deletions(-) diff --git a/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/Resources/UI/TorchSequenceSegmentation.ui b/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/Resources/UI/TorchSequenceSegmentation.ui index 45a4230..05f29f8 100644 --- a/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/Resources/UI/TorchSequenceSegmentation.ui +++ b/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/Resources/UI/TorchSequenceSegmentation.ui @@ -184,6 +184,9 @@ false + + 100 + @@ -511,6 +514,34 @@ + + + + Apply log transform: + + + + + + + + + + + + + + Normalize input to [0,1]: + + + + + + + + + + @@ -604,31 +635,40 @@ - - + + - Apply log transform: + Mask edge erosion x (%): - - - - + + + + 2 + + + 1.000000000000000 + + + 0.100000000000000 - - + + - Normalize input to [0,1]: + Mask edge erosion y (%): - - - - + + + + 1.000000000000000 + + + 0.100000000000000 diff --git a/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/TorchSequenceSegmentation.py b/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/TorchSequenceSegmentation.py index dcc71ea..0a4b8e6 100644 --- a/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/TorchSequenceSegmentation.py +++ b/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/TorchSequenceSegmentation.py @@ -223,6 +223,8 @@ def setup(self): self.ui.verticalFlipCheckbox.connect("toggled(bool)", self.updateParameterNodeFromGUI) self.ui.modelInputSizeSpinbox.connect("valueChanged(int)", self.updateParameterNodeFromGUI) self.ui.applyLogCheckBox.connect("toggled(bool)", self.updateParameterNodeFromGUI) + self.ui.edgeErosionXSpinBox.connect("valueChanged(float)", self.onErodeEdgeX) + self.ui.edgeErosionYSpinBox.connect("valueChanged(float)", self.onErodeEdgeY) lastNormalizeSetting = slicer.util.settingsValue(self.logic.LAST_NORMALIZE_SETTING, False, converter=slicer.util.toBool) self.ui.normalizeCheckBox.checked = lastNormalizeSetting @@ -236,7 +238,7 @@ def setup(self): models = self.logic.getAllModelPaths() self.ui.modelComboBox.clear() for model in models: - self.ui.modelComboBox.addItem(model.split(os.sep)[-2], model) + self.ui.modelComboBox.addItem(model.split(os.sep)[-1], model) self.ui.modelDirectoryButton.connect("directoryChanged(const QString)", self.updateSettingsFromGUI) # Set last scan conversion path in UI @@ -479,6 +481,9 @@ def updateParameterNodeFromGUI(self, caller=None, event=None): self._parameterNode.SetParameter("ApplyLogTransform", "true" if self.ui.applyLogCheckBox.checked else "false") self._parameterNode.SetParameter("ModelInputSize", str(self.ui.modelInputSizeSpinbox.value)) + # Update edge erosion parameters + self.logic.erodeCurvilinearMask(self.ui.edgeErosionXSpinBox.value, self.ui.edgeErosionYSpinBox.value) + # Update individual model to use if self.ui.useIndividualRadioButton.checked: if self.ui.modelComboBox.count > 0: @@ -505,7 +510,7 @@ def updateSettingsFromGUI(self, caller=None, event=None): self.logic.setModelsToUse(models) self.ui.modelComboBox.clear() for model in models: - self.ui.modelComboBox.addItem(model.split(os.sep)[-2], model) + self.ui.modelComboBox.addItem(model.split(os.sep)[-1], model) # Update output folder path outputFolder = self.ui.outputDirectoryButton.directory @@ -522,6 +527,12 @@ def onClearScanConversion(self): settings = qt.QSettings() settings.setValue(self.logic.LAST_SCAN_CONVERSION_PATH_SETTING, "") self.logic.loadScanConversion(None) + + def onErodeEdgeX(self, value): + self.logic.erodeCurvilinearMask(value, self.ui.edgeErosionYSpinBox.value) + + def onErodeEdgeY(self, value): + self.logic.erodeCurvilinearMask(self.ui.edgeErosionXSpinBox.value, value) def onModelSelectionMethodChanged(self, caller=None, event=None): useIndividualModel = self.ui.useIndividualRadioButton.checked @@ -778,7 +789,7 @@ def setDefaultParameters(self, parameterNode): def getAllModelPaths(self): modelFolder = slicer.util.settingsValue(self.LAST_MODEL_FOLDER_SETTING, "") if modelFolder: - models = glob.glob(os.path.join(modelFolder, "**", "*traced*.pt"), recursive=True) + models = glob.glob(os.path.join(modelFolder, "**", "*.pt"), recursive=True) normModels = [os.path.normpath(model) for model in models] # normalize paths return normModels else: @@ -870,11 +881,25 @@ def loadScanConversion(self, scanConversionPath): self.curvilinear_mask = cv2.circle(self.curvilinear_mask, (center_coordinate_pixel[1], center_coordinate_pixel[0]), radius_start_px + 1, 0, -1) + self.curvilinear_mask = self.curvilinear_mask.astype(np.uint8) # Convert mask_array to uint8 def scanConvert(self, linearArray): z = linearArray.flatten() zi = np.einsum("ij,ij->i", np.take(z, self.vertices), self.weights) return zi.reshape(self.curvilinear_size, self.curvilinear_size) + + def erodeCurvilinearMask(self, edgeErosionX, edgeErosionY): + # Erode mask by 10 percent of the image size to remove artifacts on the edges + if self.curvilinear_mask is not None and (edgeErosionX > 0 or edgeErosionY > 0): + # Repaint the borders of the mask to zero to allow erosion from all sides + self.curvilinear_mask[0, :] = 0 + self.curvilinear_mask[:, 0] = 0 + self.curvilinear_mask[-1, :] = 0 + self.curvilinear_mask[:, -1] = 0 + # Erode the mask + erosionSizeX = int(edgeErosionX * self.curvilinear_size) + erosionSizeY = int(edgeErosionY * self.curvilinear_size) + self.curvilinear_mask = cv2.erode(self.curvilinear_mask, np.ones((erosionSizeX, erosionSizeY), np.uint8), iterations=1) def getUniqueName(self, node, baseName): newName = baseName @@ -952,7 +977,7 @@ def getPrediction(self, image): if inputArray.max() <= 1.0: logging.info("Input image is already between 0 and 1, skipping normalization.") else: - inputArray = inputArray.astype(float) / inputArray.max() + inputArray = inputArray.astype(float) / 255.0 # Convert to tensor and add batch dimension inputTensor = torch.from_numpy(inputArray).unsqueeze(0).unsqueeze(0).float().to(DEVICE)