Skip to content

Commit

Permalink
Added checkbox and setting for normalizing input image
Browse files Browse the repository at this point in the history
  • Loading branch information
chriscyyeung committed Feb 13, 2024
1 parent dc4de49 commit cddb28c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<x>0</x>
<y>0</y>
<width>395</width>
<height>720</height>
<height>775</height>
</rect>
</property>
<layout class="QVBoxLayout" name="verticalLayout">
Expand Down Expand Up @@ -511,14 +511,14 @@
</property>
</widget>
</item>
<item row="3" column="0">
<item row="4" column="0">
<widget class="QLabel" name="label_8">
<property name="text">
<string>Model input size:</string>
</property>
</widget>
</item>
<item row="3" column="1">
<item row="4" column="1">
<widget class="QSpinBox" name="modelInputSizeSpinbox">
<property name="toolTip">
<string>Size of the input image of the loaded model. Assumes a square image. Only modify if shape metadata is not included in the TorchScript model.</string>
Expand All @@ -528,14 +528,14 @@
</property>
</widget>
</item>
<item row="4" column="0">
<item row="5" column="0">
<widget class="QLabel" name="label_11">
<property name="text">
<string>Output transform:</string>
</property>
</widget>
</item>
<item row="4" column="1">
<item row="5" column="1">
<widget class="qMRMLNodeComboBox" name="outputTransformSelector">
<property name="enabled">
<bool>true</bool>
Expand All @@ -562,14 +562,14 @@
</property>
</widget>
</item>
<item row="6" column="0">
<item row="7" column="0">
<widget class="QLabel" name="label_4">
<property name="text">
<string>Scan conversion config:</string>
</property>
</widget>
</item>
<item row="6" column="1">
<item row="7" column="1">
<layout class="QHBoxLayout" name="horizontalLayout_5">
<item>
<widget class="ctkPathLineEdit" name="scanConversionPathLineEdit">
Expand Down Expand Up @@ -618,6 +618,20 @@
</property>
</widget>
</item>
<item row="3" column="0">
<widget class="QLabel" name="label_18">
<property name="text">
<string>Normalize input to [0,1]:</string>
</property>
</widget>
</item>
<item row="3" column="1">
<widget class="QCheckBox" name="normalizeCheckBox">
<property name="text">
<string/>
</property>
</widget>
</item>
</layout>
</widget>
</item>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ def setup(self):
self.ui.modelInputSizeSpinbox.connect("valueChanged(int)", self.updateParameterNodeFromGUI)
self.ui.applyLogCheckBox.connect("toggled(bool)", self.updateParameterNodeFromGUI)

lastNormalizeSetting = slicer.util.settingsValue(self.logic.LAST_NORMALIZE_SETTING, False, converter=slicer.util.toBool)
self.ui.normalizeCheckBox.checked = lastNormalizeSetting
self.ui.normalizeCheckBox.connect("toggled(bool)", self.updateSettingsFromGUI)

# File paths
# Set last model folder in UI
lastModelFolder = slicer.util.settingsValue(self.logic.LAST_MODEL_FOLDER_SETTING, "")
Expand Down Expand Up @@ -507,6 +511,11 @@ def updateSettingsFromGUI(self, caller=None, event=None):
outputFolder = self.ui.outputDirectoryButton.directory
if outputFolder != slicer.util.settingsValue(self.logic.LAST_OUTPUT_FOLDER_SETTING, ""):
settings.setValue(self.logic.LAST_OUTPUT_FOLDER_SETTING, outputFolder)

# Update normalize setting
normalizeInput = self.ui.normalizeCheckBox.checked
if normalizeInput != slicer.util.settingsValue(self.logic.LAST_NORMALIZE_SETTING, "", converter=slicer.util.toBool):
settings.setValue(self.logic.LAST_NORMALIZE_SETTING, normalizeInput)

def onClearScanConversion(self):
self.ui.scanConversionPathLineEdit.currentPath = ""
Expand Down Expand Up @@ -730,6 +739,7 @@ class TorchSequenceSegmentationLogic(ScriptedLoadableModuleLogic):
https://github.com/Slicer/Slicer/blob/main/Base/Python/slicer/ScriptedLoadableModule.py
"""

LAST_NORMALIZE_SETTING = "TorchSequenceSegmentation/NormalizeInput"
LAST_MODEL_FOLDER_SETTING = "TorchSequenceSegmentation/LastModelFolder"
LAST_SCAN_CONVERSION_PATH_SETTING = "TorchSequenceSegmentation/LastScanConversionPath"
LAST_OUTPUT_FOLDER_SETTING = "TorchSequenceSegmentation/LastOutputFolder"
Expand Down Expand Up @@ -937,8 +947,12 @@ def getPrediction(self, image):
inputArray = np.flip(inputArray, axis=0)

# Normalize input if needed
if inputArray.max() > 1.0:
inputArray = inputArray.astype(float) / inputArray.max()
normalizeInput = slicer.util.settingsValue(self.LAST_NORMALIZE_SETTING, False, converter=slicer.util.toBool)
if normalizeInput:
if inputArray.max() <= 1.0:
logging.info("Input image is already between 0 and 1, skipping normalization.")
else:
inputArray = inputArray.astype(float) / inputArray.max()

# Convert to tensor and add batch dimension
inputTensor = torch.from_numpy(inputArray).unsqueeze(0).unsqueeze(0).float().to(DEVICE)
Expand Down

0 comments on commit cddb28c

Please sign in to comment.