From ce13a7f426dd307c70d631b95ed70dd2c85dd695 Mon Sep 17 00:00:00 2001 From: chriscyyeung Date: Tue, 30 Jul 2024 14:08:19 -0400 Subject: [PATCH] Added support for using previous ultrasound frames as model input in TorchSequenceSegmentation --- .../Resources/UI/TorchSequenceSegmentation.ui | 139 +++++++----- .../TorchSequenceSegmentation.py | 202 ++++++++++++------ 2 files changed, 217 insertions(+), 124 deletions(-) diff --git a/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/Resources/UI/TorchSequenceSegmentation.ui b/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/Resources/UI/TorchSequenceSegmentation.ui index 05f29f8..13331bc 100644 --- a/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/Resources/UI/TorchSequenceSegmentation.ui +++ b/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/Resources/UI/TorchSequenceSegmentation.ui @@ -7,7 +7,7 @@ 0 0 395 - 775 + 988 @@ -269,49 +269,43 @@ - - - - - - 0 - 0 - - - - Reconstruct the segmentation and render in 3D. - - - Qt::RightToLeft - - - Reconstruct 3D volume: - - - true - - - - - - - false - - - - 0 - 0 - - - - Start segmentation and/or reconstruction of the ultrasound sequence. - - - Start - - - - + + + Record prediction as segmentation + + + true + + + + + + + Reconstruct 3D volume + + + true + + + + + + + false + + + + 0 + 0 + + + + Start segmentation and/or reconstruction of the ultrasound sequence. + + + Start + + @@ -542,14 +536,14 @@ - + Model input size: - + Size of the input image of the loaded model. Assumes a square image. Only modify if shape metadata is not included in the TorchScript model. @@ -559,14 +553,14 @@ - + Output transform: - + true @@ -593,14 +587,14 @@ - + Scan conversion config: - + @@ -635,14 +629,14 @@ - + Mask edge erosion x (%): - + 2 @@ -655,14 +649,14 @@ - + Mask edge erosion y (%): - + 1.000000000000000 @@ -672,6 +666,43 @@ + + + + Segmentation threshold: + + + + + + + 255 + + + 127 + + + + + + + Segment name: + + + + + + + + + + Number of previous frames: + + + + + + diff --git a/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/TorchSequenceSegmentation.py b/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/TorchSequenceSegmentation.py index 0a4b8e6..aa7bc26 100644 --- a/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/TorchSequenceSegmentation.py +++ b/SlicerExtension/LiveUltrasoundAi/TorchSequenceSegmentation/TorchSequenceSegmentation.py @@ -223,8 +223,10 @@ def setup(self): self.ui.verticalFlipCheckbox.connect("toggled(bool)", self.updateParameterNodeFromGUI) self.ui.modelInputSizeSpinbox.connect("valueChanged(int)", self.updateParameterNodeFromGUI) self.ui.applyLogCheckBox.connect("toggled(bool)", self.updateParameterNodeFromGUI) - self.ui.edgeErosionXSpinBox.connect("valueChanged(float)", self.onErodeEdgeX) - self.ui.edgeErosionYSpinBox.connect("valueChanged(float)", self.onErodeEdgeY) + self.ui.edgeErosionXSpinBox.connect("valueChanged(double)", self.onErodeEdgeX) + self.ui.edgeErosionYSpinBox.connect("valueChanged(double)", self.onErodeEdgeY) + self.ui.segmentNameLineEdit.connect("textChanged(const QString)", self.updateParameterNodeFromGUI) + self.ui.thresholdSpinBox.connect("valueChanged(int)", self.updateParameterNodeFromGUI) lastNormalizeSetting = slicer.util.settingsValue(self.logic.LAST_NORMALIZE_SETTING, False, converter=slicer.util.toBool) self.ui.normalizeCheckBox.checked = lastNormalizeSetting @@ -431,6 +433,12 @@ def updateGUIFromParameterNode(self, caller=None, event=None): modelInputSize = 
self._parameterNode.GetParameter("ModelInputSize") self.ui.modelInputSizeSpinbox.setValue(int(modelInputSize) if modelInputSize else 0) + segmentName = self._parameterNode.GetParameter("SegmentName") + self.ui.segmentNameLineEdit.setText(segmentName) + + threshold = self._parameterNode.GetParameter("Threshold") + self.ui.thresholdSpinBox.setValue(int(threshold) if threshold else 0) + # Change output transform to parent of input volume if inputVolume: inputVolumeParent = inputVolume.GetParentTransformNode() @@ -443,7 +451,7 @@ def updateGUIFromParameterNode(self, caller=None, event=None): self.ui.outputTransformSelector.blockSignals(wasBlocked) # Enable/disable buttons - if self.ui.reconstructCheckBox.checked: + if self.ui.reconstructButton.checked: self.ui.startButton.setEnabled(sequenceBrowser and inputVolume and volumeReconstructionNode @@ -480,6 +488,8 @@ def updateParameterNodeFromGUI(self, caller=None, event=None): self._parameterNode.SetParameter("FlipVertical", "true" if self.ui.verticalFlipCheckbox.checked else "false") self._parameterNode.SetParameter("ApplyLogTransform", "true" if self.ui.applyLogCheckBox.checked else "false") self._parameterNode.SetParameter("ModelInputSize", str(self.ui.modelInputSizeSpinbox.value)) + self._parameterNode.SetParameter("SegmentName", self.ui.segmentNameLineEdit.text) + self._parameterNode.SetParameter("Threshold", str(self.ui.thresholdSpinBox.value)) # Update edge erosion parameters self.logic.erodeCurvilinearMask(self.ui.edgeErosionXSpinBox.value, self.ui.edgeErosionYSpinBox.value) @@ -605,7 +615,8 @@ def onStartButton(self): self.ui.sequenceBrowserSelector.setEnabled(False) self.ui.inputVolumeSelector.setEnabled(False) self.ui.volumeReconstructionSelector.setEnabled(False) - self.ui.reconstructCheckBox.setEnabled(False) + self.ui.reconstructButton.setEnabled(False) + self.ui.recordAsSegmentationButton.setEnabled(False) self.ui.verticalFlipCheckbox.setEnabled(False) self.ui.applyLogCheckBox.setEnabled(False) @@ -616,7 +627,7 @@ def onStartButton(self): # Overall progress bar numModels = len(self.logic.getModelsToUse()) - progressMax = numModels * 2 if self.ui.reconstructCheckBox.checked else numModels + progressMax = numModels * 2 if self.ui.reconstructButton.checked else numModels self.ui.overallProgressBar.setMaximum(progressMax) slicer.app.processEvents() @@ -641,10 +652,10 @@ def onStartButton(self): self._parameterNode.SetNodeReferenceID("SequenceBrowser", sequenceBrowser.GetID()) numFrames = sequenceBrowser.GetMasterSequenceNode().GetNumberOfDataNodes() - 1 self.setPredictionProgressBar(numFrames) - self.logic.segmentSequence(model) + self.logic.segmentSequence(model, self.ui.recordAsSegmentationButton.checked, int(self.ui.previousFramesSpinBox.value)) self.resetTaskProgressBar() - if self.ui.reconstructCheckBox.checked: + if self.ui.reconstructButton.checked: self.ui.overallProgressBar.setValue(self.ui.overallProgressBar.value + 1) self.ui.taskStatusLabel.setText("Reconstructing volume...") self.setReconstructionProgressBar() @@ -675,7 +686,8 @@ def onStartButton(self): self.ui.sequenceBrowserSelector.setEnabled(True) self.ui.inputVolumeSelector.setEnabled(True) self.ui.volumeReconstructionSelector.setEnabled(True) - self.ui.reconstructCheckBox.setEnabled(True) + self.ui.reconstructButton.setEnabled(True) + self.ui.recordAsSegmentationButton.setEnabled(True) self.ui.verticalFlipCheckbox.setEnabled(True) self.ui.applyLogCheckBox.setEnabled(True) @@ -785,6 +797,10 @@ def setDefaultParameters(self, parameterNode): 
parameterNode.SetParameter("UseAllBrowsers", "false") if not parameterNode.GetParameter("Invert"): parameterNode.SetParameter("Invert", "false") + if not parameterNode.GetParameter("SegmentName"): + parameterNode.SetParameter("SegmentName", "Segmentation") + if not parameterNode.GetParameter("Threshold"): + parameterNode.SetParameter("Threshold", "127") def getAllModelPaths(self): modelFolder = slicer.util.settingsValue(self.LAST_MODEL_FOLDER_SETTING, "") @@ -918,54 +934,13 @@ def getUniqueName(self, node, baseName): def getUseAllBrowsers(self): parameterNode = self.getParameterNode() return parameterNode.GetParameter("UseAllBrowsers") == "true" - - def addPredictionVolume(self, modelName): - parameterNode = self.getParameterNode() - - # Make new prediction volume to not overwrite existing one - predictionVolume = parameterNode.GetNodeReference("PredictionVolume") - volumeName = self.getUniqueName(predictionVolume, f"{modelName.split(os.sep)[-2]}_Prediction") - predictionVolume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLScalarVolumeNode", volumeName) - predictionVolume.CreateDefaultDisplayNodes() - parameterNode.SetNodeReferenceID("PredictionVolume", predictionVolume.GetID()) - - # Place in output transform if it exists - outputTransform = parameterNode.GetNodeReference("OutputTransform") - if outputTransform: - predictionVolume.SetAndObserveTransformNodeID(outputTransform.GetID()) - - return predictionVolume - - def addPredictionSequenceNode(self, predictionVolume, modelName): - parameterNode = self.getParameterNode() - sequenceBrowser = parameterNode.GetNodeReference("SequenceBrowser") - # Add a new sequence node to the sequence browser - masterSequenceNode = sequenceBrowser.GetMasterSequenceNode() - sequenceName = self.getUniqueName(masterSequenceNode, f"{modelName.split(os.sep)[-2]}_PredictionSequence") - predictionSequenceNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSequenceNode", sequenceName) - sequenceBrowser.AddSynchronizedSequenceNode(predictionSequenceNode) - sequenceBrowser.AddProxyNode(predictionVolume, predictionSequenceNode, False) - - return predictionSequenceNode - - def getPrediction(self, image): + def getPrediction(self, inputArray): if not self.model: return - - if not image.GetImageData(): - return - - imageArray = slicer.util.arrayFromVolume(image) + parameterNode = self.getParameterNode() - # Use inverse scan conversion if specified by user, otherwise resize - if self.scanConversionDict: - inputArray = map_coordinates(imageArray[0, :, :], [self.cart_x, self.cart_y], order=1) - else: - inputSize = int(parameterNode.GetParameter("ModelInputSize")) - inputArray = cv2.resize(imageArray[0, :, :], (inputSize, inputSize)) # default is bilinear - # Flip image vertically if specified by user toFlip = parameterNode.GetParameter("FlipVertical").lower() == "true" if toFlip: @@ -980,7 +955,7 @@ def getPrediction(self, image): inputArray = inputArray.astype(float) / 255.0 # Convert to tensor and add batch dimension - inputTensor = torch.from_numpy(inputArray).unsqueeze(0).unsqueeze(0).float().to(DEVICE) + inputTensor = torch.from_numpy(inputArray).unsqueeze(0).float().to(DEVICE) # Run prediction with torch.inference_mode(): @@ -996,13 +971,6 @@ def getPrediction(self, image): if toFlip: outputArray = np.flip(outputArray, axis=0) - # Scan convert or resize - if self.scanConversionDict: - outputArray = self.scanConvert(outputArray) - outputArray *= self.curvilinear_mask - else: - outputArray = cv2.resize(outputArray, (imageArray.shape[2], imageArray.shape[1])) - if 
parameterNode.GetParameter("ApplyLogTransform").lower() == "true": e = self.LOGARITHMIC_TRANSFORMATION_DECIMALS outputArray = np.log10(np.clip(outputArray, 10 ** (-e), 1.0) * (10 ** e)) / e @@ -1013,35 +981,129 @@ def getPrediction(self, image): return outputArray - def segmentSequence(self, modelName): + def segmentSequence(self, modelName, recordAsSegmentation=False, numPreviousFrames=0): self.isProcessing = True parameterNode = self.getParameterNode() sequenceBrowser = parameterNode.GetNodeReference("SequenceBrowser") inputVolume = parameterNode.GetNodeReference("InputVolume") inputSequence = sequenceBrowser.GetSequenceNode(inputVolume) + modelBasename = modelName.split(os.sep)[-2] + segmentName = parameterNode.GetParameter("SegmentName") + threshold = int(parameterNode.GetParameter("Threshold")) - # Create prediction sequence - predictionVolume = self.addPredictionVolume(modelName) - predictionSequenceNode = self.addPredictionSequenceNode(predictionVolume, modelName) + # Make new prediction volume to not overwrite existing one + predictionVolume = parameterNode.GetNodeReference("PredictionVolume") + volumeName = self.getUniqueName(predictionVolume, f"{modelBasename}_Prediction") + predictionVolume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLScalarVolumeNode", volumeName) + predictionVolume.CreateDefaultDisplayNodes() + parameterNode.SetNodeReferenceID("PredictionVolume", predictionVolume.GetID()) + + # Add a new sequence node to the sequence browser + masterSequenceNode = sequenceBrowser.GetMasterSequenceNode() + sequenceName = self.getUniqueName(masterSequenceNode, f"{modelBasename}_PredictionSequence") + predictionSequenceNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSequenceNode", sequenceName) + sequenceBrowser.AddSynchronizedSequenceNode(predictionSequenceNode) + sequenceBrowser.AddProxyNode(predictionVolume, predictionSequenceNode, False) + + # Add segmentation node from prediction volume + if recordAsSegmentation: + segmentationNode = parameterNode.GetNodeReference("Segmentation") + segmentationName = self.getUniqueName(segmentationNode, f"{modelBasename}_Segmentation") + segmentationNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode", segmentationName) + segmentationNode.CreateDefaultDisplayNodes() + segmentationNode.GetDisplayNode().SetVisibility(False) + segmentationNode.SetReferenceImageGeometryParameterFromVolumeNode(inputVolume) + segmentationNode.GetSegmentation().AddEmptySegment(segmentName) + ids = vtk.vtkStringArray() + ids.InsertNextValue(segmentName) + parameterNode.SetNodeReferenceID("Segmentation", segmentationNode.GetID()) + + # Add segmentation node to sequence browser + segSeqName = self.getUniqueName(masterSequenceNode, f"{modelBasename}_SegmentationSequence") + segSequenceNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSequenceNode", segSeqName) + sequenceBrowser.AddSynchronizedSequenceNode(segSequenceNode) + sequenceBrowser.AddProxyNode(segmentationNode, segSequenceNode, False) + sequenceBrowser.SetSaveChanges(segSequenceNode, True) + + # Place in output transform if it exists + outputTransform = parameterNode.GetNodeReference("OutputTransform") + if outputTransform: + predictionVolume.SetAndObserveTransformNodeID(outputTransform.GetID()) + if recordAsSegmentation: + segmentationNode.SetAndObserveTransformNodeID(outputTransform.GetID()) # Overlay prediction volume in slice view predictionDisplayNode = predictionVolume.GetDisplayNode() predictionDisplayNode.SetAndObserveColorNodeID("vtkMRMLColorTableNodeGreen") 
        slicer.util.setSliceViewerLayers(foreground=predictionVolume, foregroundOpacity=0.3)
 
-        # Iterate through each item in sequence browser and add generated segmentation
+        # Create list for previous frame buffer
+        if numPreviousFrames > 0:
+            frameBufferList = []
+
+        selectedItemNumber = sequenceBrowser.GetSelectedItemNumber()  # for restoring later
+
+        # Iterate through each item in sequence browser and add generated segmentation
         for itemIndex in range(sequenceBrowser.GetNumberOfItems()):
-            # Generate segmentation
-            currentImage = inputSequence.GetNthDataNode(itemIndex)
+            # Get current frame
+            image = inputSequence.GetNthDataNode(itemIndex)
+            imageArray = slicer.util.arrayFromVolume(image)
+            originalImageShape = imageArray.shape
             sequenceBrowser.SetSelectedItemNumber(itemIndex)
-            prediction = self.getPrediction(currentImage)
-            slicer.util.updateVolumeFromArray(predictionVolume, prediction)
 
-            # Add segmentation to sequence browser
-            indexValue = inputSequence.GetNthIndexValue(itemIndex)
+            # Use inverse scan conversion if specified by user, otherwise resize
+            if self.scanConversionDict:
+                imageArray = map_coordinates(imageArray[0, :, :], [self.cart_x, self.cart_y], order=1)
+            else:
+                inputSize = int(parameterNode.GetParameter("ModelInputSize"))
+                imageArray = cv2.resize(imageArray[0, :, :], (inputSize, inputSize))  # default is bilinear
+
+            # Update previous frame buffer: pad with copies of the first frame,
+            # then drop the oldest frame and prepend the current one, so the
+            # current frame is always the first channel of the model input
+            if numPreviousFrames > 0:
+                if itemIndex == 0:
+                    frameBufferList.append(imageArray)
+                    frameBufferList *= numPreviousFrames + 1
+                else:
+                    frameBufferList.pop()
+                    frameBufferList.insert(0, imageArray)
+                inputArray = np.stack(frameBufferList, axis=0)
+            else:
+                inputArray = np.expand_dims(imageArray, axis=0)
+
+            # Generate segmentation
+            prediction = self.getPrediction(inputArray)
+
+            # Scan convert or resize
+            if self.scanConversionDict:
+                prediction = self.scanConvert(prediction)
+                prediction *= self.curvilinear_mask
+            else:
+                prediction = cv2.resize(prediction, (originalImageShape[2], originalImageShape[1]))
+
+            slicer.util.updateVolumeFromArray(predictionVolume, prediction)
+            indexValue = masterSequenceNode.GetNthIndexValue(itemIndex)
             predictionSequenceNode.SetDataNodeAtValue(predictionVolume, indexValue)
+
+            if recordAsSegmentation:
+                # Create temporary label map
+                labelmapVolume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode")
+                slicer.modules.volumes.logic().CreateLabelVolumeFromVolume(slicer.mrmlScene, labelmapVolume, predictionVolume)
+
+                # Fill label map by thresholding prediction (broadcast the 2D
+                # prediction over the single-slice label map array)
+                labelmapArray = slicer.util.arrayFromVolume(labelmapVolume)
+                labelmapArray[:] = np.where(prediction >= threshold, 1, 0)
+                slicer.util.arrayFromVolumeModified(labelmapVolume)
+
+                # Import label map to segmentation
+                slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(labelmapVolume, segmentationNode, ids)
+
+                # Add segmentation to sequence browser
+                segSequenceNode.SetDataNodeAtValue(segmentationNode, indexValue)
+                slicer.mrmlScene.RemoveNode(labelmapVolume)
+
             if self.progressCallback:
                 self.progressCallback(itemIndex)
 
         sequenceBrowser.SetSelectedItemNumber(selectedItemNumber)
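
Note for reviewers: the sliding-window input assembled in segmentSequence() boils down to the standalone sketch below. This is a minimal sketch, not part of the module: `frames` (a list of 2D uint8 numpy arrays) and `model` (a TorchScript module expecting numPreviousFrames + 1 input channels) are hypothetical placeholders, and the module's preprocessing (inverse scan conversion, vertical flip, optional normalization) is omitted.

    import numpy as np
    import torch

    def segmentWithHistory(frames, model, numPreviousFrames=2, device="cpu"):
        """Run a model over a frame sequence, stacking previous frames as channels."""
        frameBuffer = []
        predictions = []
        for i, frame in enumerate(frames):
            if i == 0:
                # Pad the history with copies of the first frame so the channel
                # count is constant from the very first prediction
                frameBuffer = [frame] * (numPreviousFrames + 1)
            else:
                # Drop the oldest frame and prepend the current one
                frameBuffer.pop()
                frameBuffer.insert(0, frame)
            # (numPreviousFrames + 1, H, W) -> (1, numPreviousFrames + 1, H, W)
            inputArray = np.stack(frameBuffer, axis=0).astype(np.float32) / 255.0
            inputTensor = torch.from_numpy(inputArray).unsqueeze(0).to(device)
            with torch.inference_mode():
                output = model(inputTensor)
            predictions.append(output.squeeze().cpu().numpy())
        return predictions

Padding the buffer with copies of the first frame trades briefly duplicated temporal context at the start of the sweep for a fixed input shape, which keeps the TorchScript model's expected channel count satisfied on every iteration.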