Skip to content

Commit

Permalink
Updated TorchSequenceSegmentation to include all models in folder. Ad…
Browse files Browse the repository at this point in the history
…ded edge erosion to curvilinear mask.
  • Loading branch information
chriscyyeung committed Jun 14, 2024
1 parent db94eaa commit 491384f
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@
<property name="editable">
<bool>false</bool>
</property>
<property name="maxVisibleItems">
<number>100</number>
</property>
</widget>
</item>
<item row="6" column="0">
Expand Down Expand Up @@ -511,6 +514,34 @@
</property>
</widget>
</item>
<item row="2" column="0">
<widget class="QLabel" name="label_17">
<property name="text">
<string>Apply log transform:</string>
</property>
</widget>
</item>
<item row="2" column="1">
<widget class="QCheckBox" name="applyLogCheckBox">
<property name="text">
<string/>
</property>
</widget>
</item>
<item row="3" column="0">
<widget class="QLabel" name="label_18">
<property name="text">
<string>Normalize input to [0,1]:</string>
</property>
</widget>
</item>
<item row="3" column="1">
<widget class="QCheckBox" name="normalizeCheckBox">
<property name="text">
<string/>
</property>
</widget>
</item>
<item row="4" column="0">
<widget class="QLabel" name="label_8">
<property name="text">
Expand Down Expand Up @@ -604,31 +635,40 @@
</item>
</layout>
</item>
<item row="2" column="0">
<widget class="QLabel" name="label_17">
<item row="8" column="0">
<widget class="QLabel" name="label_19">
<property name="text">
<string>Apply log transform:</string>
<string>Mask edge erosion x (%):</string>
</property>
</widget>
</item>
<item row="2" column="1">
<widget class="QCheckBox" name="applyLogCheckBox">
<property name="text">
<string/>
<item row="8" column="1">
<widget class="QDoubleSpinBox" name="edgeErosionXSpinBox">
<property name="decimals">
<number>2</number>
</property>
<property name="maximum">
<double>1.000000000000000</double>
</property>
<property name="singleStep">
<double>0.100000000000000</double>
</property>
</widget>
</item>
<item row="3" column="0">
<widget class="QLabel" name="label_18">
<item row="9" column="0">
<widget class="QLabel" name="label_20">
<property name="text">
<string>Normalize input to [0,1]:</string>
<string>Mask edge erosion y (%):</string>
</property>
</widget>
</item>
<item row="3" column="1">
<widget class="QCheckBox" name="normalizeCheckBox">
<property name="text">
<string/>
<item row="9" column="1">
<widget class="QDoubleSpinBox" name="edgeErosionYSpinBox">
<property name="maximum">
<double>1.000000000000000</double>
</property>
<property name="singleStep">
<double>0.100000000000000</double>
</property>
</widget>
</item>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ def setup(self):
self.ui.verticalFlipCheckbox.connect("toggled(bool)", self.updateParameterNodeFromGUI)
self.ui.modelInputSizeSpinbox.connect("valueChanged(int)", self.updateParameterNodeFromGUI)
self.ui.applyLogCheckBox.connect("toggled(bool)", self.updateParameterNodeFromGUI)
self.ui.edgeErosionXSpinBox.connect("valueChanged(float)", self.onErodeEdgeX)
self.ui.edgeErosionYSpinBox.connect("valueChanged(float)", self.onErodeEdgeY)

lastNormalizeSetting = slicer.util.settingsValue(self.logic.LAST_NORMALIZE_SETTING, False, converter=slicer.util.toBool)
self.ui.normalizeCheckBox.checked = lastNormalizeSetting
Expand All @@ -236,7 +238,7 @@ def setup(self):
models = self.logic.getAllModelPaths()
self.ui.modelComboBox.clear()
for model in models:
self.ui.modelComboBox.addItem(model.split(os.sep)[-2], model)
self.ui.modelComboBox.addItem(model.split(os.sep)[-1], model)
self.ui.modelDirectoryButton.connect("directoryChanged(const QString)", self.updateSettingsFromGUI)

# Set last scan conversion path in UI
Expand Down Expand Up @@ -479,6 +481,9 @@ def updateParameterNodeFromGUI(self, caller=None, event=None):
self._parameterNode.SetParameter("ApplyLogTransform", "true" if self.ui.applyLogCheckBox.checked else "false")
self._parameterNode.SetParameter("ModelInputSize", str(self.ui.modelInputSizeSpinbox.value))

# Update edge erosion parameters
self.logic.erodeCurvilinearMask(self.ui.edgeErosionXSpinBox.value, self.ui.edgeErosionYSpinBox.value)

# Update individual model to use
if self.ui.useIndividualRadioButton.checked:
if self.ui.modelComboBox.count > 0:
Expand All @@ -505,7 +510,7 @@ def updateSettingsFromGUI(self, caller=None, event=None):
self.logic.setModelsToUse(models)
self.ui.modelComboBox.clear()
for model in models:
self.ui.modelComboBox.addItem(model.split(os.sep)[-2], model)
self.ui.modelComboBox.addItem(model.split(os.sep)[-1], model)

# Update output folder path
outputFolder = self.ui.outputDirectoryButton.directory
Expand All @@ -522,6 +527,12 @@ def onClearScanConversion(self):
settings = qt.QSettings()
settings.setValue(self.logic.LAST_SCAN_CONVERSION_PATH_SETTING, "")
self.logic.loadScanConversion(None)

def onErodeEdgeX(self, value):
self.logic.erodeCurvilinearMask(value, self.ui.edgeErosionYSpinBox.value)

def onErodeEdgeY(self, value):
self.logic.erodeCurvilinearMask(self.ui.edgeErosionXSpinBox.value, value)

def onModelSelectionMethodChanged(self, caller=None, event=None):
useIndividualModel = self.ui.useIndividualRadioButton.checked
Expand Down Expand Up @@ -778,7 +789,7 @@ def setDefaultParameters(self, parameterNode):
def getAllModelPaths(self):
modelFolder = slicer.util.settingsValue(self.LAST_MODEL_FOLDER_SETTING, "")
if modelFolder:
models = glob.glob(os.path.join(modelFolder, "**", "*traced*.pt"), recursive=True)
models = glob.glob(os.path.join(modelFolder, "**", "*.pt"), recursive=True)
normModels = [os.path.normpath(model) for model in models] # normalize paths
return normModels
else:
Expand Down Expand Up @@ -870,11 +881,25 @@ def loadScanConversion(self, scanConversionPath):
self.curvilinear_mask = cv2.circle(self.curvilinear_mask,
(center_coordinate_pixel[1], center_coordinate_pixel[0]),
radius_start_px + 1, 0, -1)
self.curvilinear_mask = self.curvilinear_mask.astype(np.uint8) # Convert mask_array to uint8

def scanConvert(self, linearArray):
z = linearArray.flatten()
zi = np.einsum("ij,ij->i", np.take(z, self.vertices), self.weights)
return zi.reshape(self.curvilinear_size, self.curvilinear_size)

def erodeCurvilinearMask(self, edgeErosionX, edgeErosionY):
# Erode mask by 10 percent of the image size to remove artifacts on the edges
if self.curvilinear_mask is not None and (edgeErosionX > 0 or edgeErosionY > 0):
# Repaint the borders of the mask to zero to allow erosion from all sides
self.curvilinear_mask[0, :] = 0
self.curvilinear_mask[:, 0] = 0
self.curvilinear_mask[-1, :] = 0
self.curvilinear_mask[:, -1] = 0
# Erode the mask
erosionSizeX = int(edgeErosionX * self.curvilinear_size)
erosionSizeY = int(edgeErosionY * self.curvilinear_size)
self.curvilinear_mask = cv2.erode(self.curvilinear_mask, np.ones((erosionSizeX, erosionSizeY), np.uint8), iterations=1)

def getUniqueName(self, node, baseName):
newName = baseName
Expand Down Expand Up @@ -952,7 +977,7 @@ def getPrediction(self, image):
if inputArray.max() <= 1.0:
logging.info("Input image is already between 0 and 1, skipping normalization.")
else:
inputArray = inputArray.astype(float) / inputArray.max()
inputArray = inputArray.astype(float) / 255.0

# Convert to tensor and add batch dimension
inputTensor = torch.from_numpy(inputArray).unsqueeze(0).unsqueeze(0).float().to(DEVICE)
Expand Down

0 comments on commit 491384f

Please sign in to comment.