Skip to content

Commit

Permalink
re #24: Switched the save location selector to a button to work better
Browse files Browse the repository at this point in the history
re #27: reformatted file names, created version 2 of the SOP
re #28: The script now works with the current model output
  • Loading branch information
16djm10 committed Nov 22, 2022
1 parent 931cf29 commit 336f202
Show file tree
Hide file tree
Showing 4 changed files with 516 additions and 953 deletions.
119 changes: 59 additions & 60 deletions BroadbandSpecModule/BroadbandSpecModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class BroadbandSpecModule(ScriptedLoadableModule):

def __init__(self, parent):
ScriptedLoadableModule.__init__(self, parent)
self.parent.title = "BroadbandSpecModule" # TODO: make this more human readable by adding spaces
self.parent.categories = ["Broadband"]
self.parent.title = "Broadband Spectroscopy" # TODO: make this more human readable by adding spaces
self.parent.categories = ["Spectroscopy"]
self.parent.dependencies = []
parent.contributors = ["David Morton (Queen's University, PERK Lab)"]
self.parent.helpText = """
Expand Down Expand Up @@ -94,6 +94,10 @@ def setup(self):
# Create logic class. Logic implements all computations that should be possible to run
# in batch mode, without a graphical user interface.
self.logic = BroadbandSpecModuleLogic()
# Make sure parameter node is initialized (needed for module reload)
self.initializeParameterNode()
self.initializeScene()
self.logic.setupLists()

# Connections

Expand All @@ -119,18 +123,22 @@ def setup(self):
self.ui.clearLastPointButton.connect('clicked(bool)', self.onClearLastPointButtonClicked)
# Data Collection
self.ui.dataClassSelector.connect('currentIndexChanged(int)', self.onDataClassSelectorChanged)
self.ui.saveLocationSelector.connect('currentPathChanged(QString)', self.onSaveLocationSelectorChanged)
# add the options cancer and normal to the data class selector
self.ui.dataClassSelector.addItem("Cancer")
self.ui.dataClassSelector.addItem("Normal")
self.ui.patientNumberSelector.connect('currentIndexChanged(int)', self.onPatientNumberSelectorChanged)
# add the options Patient1 the patient number selector
self.ui.patientNumberSelector.addItem("Patient1")
# self.ui.saveLocationSelector.connect('currentPathChanged(QString)', self.onSaveLocationSelectorChanged)
self.ui.saveDirectoryButton.connect('directorySelected(QString)', self.onSaveDirectoryButtonClicked)
# update the saveDirectory using parameter node to the self.logic.SAVE_LOCATION
self.ui.saveDirectoryButton.directory = self._parameterNode.GetParameter(self.logic.SAVE_LOCATION)

self.ui.samplingDurationSlider.connect('valueChanged(double)', self.onSamplingDurationChanged)
self.ui.samplingRateSlider.connect('valueChanged(double)', self.onSamplingRateChanged)
self.ui.collectSampleButton.connect('clicked(bool)', self.onCollectSampleButtonClicked)
self.ui.continuousCollectionButton.connect('clicked(bool)', self.onContinuousCollectionButtonClicked)
# Make sure parameter node is initialized (needed for module reload)
self.initializeParameterNode()
self.initializeScene()
self.logic.setupLists()


def initializeScene(self):
# NeedleModel
Expand All @@ -145,11 +153,6 @@ def initializeScene(self):
# Add it to parameter node
self._parameterNode.SetNodeReferenceID(self.logic.NEEDLE_MODEL, needleModel.GetID())

# # If not already transformed, add it to the NeedleToRas transform
# needleModelTransform = needleModel.GetParentTransformNode()
# if needleModelTransform is None:
# needleModel.SetAndObserveTransformNodeID(needleToRasTransform.GetID())

# NeedleTip pointlist

# If pointList_NeedleTip is not in the scene, create and add it
Expand All @@ -165,10 +168,6 @@ def initializeScene(self):
if pointList_EMT.GetNumberOfControlPoints() == 0:
pointList_EMT.AddControlPoint(np.array([0, 0, 0]))
pointList_EMT.SetNthControlPointLabel(0, "origin_Tip")
# # If not already transformed, add it to the NeedleToRas transform
# pointList_NeedleTipTransform = pointList_NeedleTip.GetParentTransformNode()
# if pointList_NeedleTipTransform is None:
# pointList_NeedleTip.SetAndObserveTransformNodeID(needleToRasTransform.GetID())
pass

# My functions
Expand All @@ -188,6 +187,12 @@ def onDataClassSelectorChanged(self):
dataClass = self.ui.dataClassSelector.currentText
parameterNode.SetParameter(self.logic.DATA_CLASS, dataClass)

def onPatientNumberSelectorChanged(self):
self.updateParameterNodeFromGUI()
parameterNode = self.logic.getParameterNode()
patientNumber = self.ui.patientNumberSelector.currentText
parameterNode.SetParameter(self.logic.PATIENT_NUM, patientNumber)

def onSamplingDurationChanged(self):
self.updateParameterNodeFromGUI()
parameterNode = self.logic.getParameterNode()
Expand Down Expand Up @@ -251,14 +256,17 @@ def onConnectButtonClicked(self):
pointList_NeedleTip = parameterNode.GetNodeReference(self.logic.POINTLIST_EMT)
pointList_NeedleTip.SetAndObserveTransformNodeID(transformNode.GetID())

def onSaveLocationSelectorChanged(self):
self.updateParameterNodeFromGUI()
# get path from the selector
path = self.ui.saveLocationSelector.currentPath
# get parameter node
parameterNode = self.logic.getParameterNode()
# set the path in the parameter node
parameterNode.SetParameter(self.logic.SAVE_LOCATION, path)

def onSaveDirectoryButtonClicked(self):
self.updateParameterNodeFromGUI()
# get the path from the button
path = self.ui.saveDirectoryButton.directory
# Print the save directory
print('Save directory: ' + path)
# Save that path to the parameter node
parameterNode = self.logic.getParameterNode()
parameterNode.SetParameter(self.logic.SAVE_LOCATION, path)
# pass

def onSpectrumImageChanged(self):
self.updateParameterNodeFromGUI()
Expand Down Expand Up @@ -305,7 +313,6 @@ def setEnableClassification(self, enable):
def onClearLastPointButtonClicked(self):
self.updateParameterNodeFromGUI()
# Check to see if the lists exist, and if not create them
# self.logic.setupLists()
print("This button is not currently implemented")
pass

Expand All @@ -316,7 +323,6 @@ def onClearControlPointsButtonClicked(self):
def onAddControlPointButtonClicked(self):
self.updateParameterNodeFromGUI()
# Check to see if the lists exist, and if not create them
# self.logic.setupLists()
self.logic.addControlPointToToolTip()

def onScanButtonClicked(self, checked):
Expand Down Expand Up @@ -359,7 +365,6 @@ def initializeParameterNode(self):
self.setParameterNode(self.logic.getParameterNode())

# Ensure the required lists are created and reference in the parameter node
# self.logic.setupLists()

# Select default input nodes if nothing is selected yet to save a few clicks for the user
if not self._parameterNode.GetNodeReference(self.logic.INPUT_VOLUME):
Expand Down Expand Up @@ -416,11 +421,6 @@ def updateGUIFromParameterNode(self, caller=None, event=None):
# set the current path to whatever is stored in the parameter node
self.ui.modelFileSelector.currentPath = self._parameterNode.GetParameter(self.logic.MODEL_PATH)

# Update the save location to be the last selection
if self.ui.saveLocationSelector.currentPath == '':
# set the current path to whatever is stored in the parameter node
self.ui.saveLocationSelector.currentPath = self._parameterNode.GetParameter(self.logic.SAVE_LOCATION)

# All the GUI updates are done
self._updatingGUIFromParameterNode = False

Expand All @@ -441,7 +441,6 @@ def updateParameterNodeFromGUI(self, caller=None, event=None):
parameterNode.SetParameter(self.logic.SCANNING_STATE, str(self.ui.scanButton.isChecked()))
# update parameter node with the state of enable plotting button
parameterNode.SetParameter(self.logic.PLOTTING_STATE, str(self.ui.enablePlottingButton.isChecked()))
# print("Plotting state: " + str(self.ui.enablePlottingButton.isChecked()))
# update parameter node with the state of enable classification button
parameterNode.SetParameter(self.logic.CLASSIFYING_STATE, str(self.ui.enableClassificationButton.isChecked()))
# update parameter node with the current path of the file selector
Expand Down Expand Up @@ -510,6 +509,7 @@ class BroadbandSpecModuleLogic(ScriptedLoadableModuleLogic,VTKObservationMixin):
SAMPLING_DURATION = "Sample Duration" # Parameter stores the duration of the sampling
SAMPLING_RATE = "Sample Rate" # Parameter stores the rate of the sampling
DATA_CLASS = "Data Class" # Parameter stores the data class we are recording
PATIENT_NUM = "Patient Number" # Parameter stores the patient number
SAVE_LOCATION = 'Save Location' # Parameter stores the location where the data is saved

# ROLES
Expand All @@ -530,8 +530,9 @@ class BroadbandSpecModuleLogic(ScriptedLoadableModuleLogic,VTKObservationMixin):
CLASS_LABEL_1 = "ClassLabel1" # The label of the second class
CLASS_LABEL_NONE = "WeakSignal" # The label of the class when the signal is too weak
DISTANCE_THRESHOLD = 1 # in mm
DEFAULT_SAVE_LOCATION = os.path.join('C:\Spectroscopy_TrackedTissueSensing\data', 'Nov2022_skinTestData')
# DEFAULT_SAVE_LOCATION = os.path.join('C:\Spectroscopy_TrackedTissueSensing\data', 'Nov2022_skinTestData')
DEFAULT_MODEL_PATH = os.path.join('C:\Spectroscopy_TrackedTissueSensing\TrainedModels', 'KNN_WhiteVsBlue2.joblib')
DEFAULT_SAVE_LOCATION = os.path.join('C:\Spectroscopy_TrackedTissueSensing\data', 'SkinDataCollection')

def __init__(self):
"""
Expand Down Expand Up @@ -560,9 +561,8 @@ def setDefaultParameters(self, parameterNode):
parameterNode.SetParameter(self.CLASSIFICATION, '')
# if the self.model path is not set, grab it from the ctkPathLineEdit widget
if parameterNode.GetParameter(self.MODEL_PATH) == '':
parameterNode.SetParameter(self.MODEL_PATH, 'C:/Spectroscopy_TrackedTissueSensing/TrainedModels/KNN_PorkVsBeefTest.joblib') # Hardcoded path
parameterNode.SetParameter(self.MODEL_PATH, self.DEFAULT_MODEL_PATH)
if self.model == None:
# parameterNode = self.getParameterNode()
modelPath = parameterNode.GetParameter(self.MODEL_PATH)
self.model = load(modelPath)

Expand Down Expand Up @@ -602,7 +602,7 @@ def addObservers(self):
parameterNode = self.getParameterNode()
spectrumImageNode = parameterNode.GetNodeReference(self.INPUT_VOLUME)
if spectrumImageNode:
print("Add observer to {0}".format(spectrumImageNode.GetName()))
# print("Add observer to {0}".format(spectrumImageNode.GetName()))
self.observerTags.append([spectrumImageNode, spectrumImageNode.AddObserver(vtk.vtkCommand.ModifiedEvent, self.onSpectrumImageNodeModified)])

def removeObservers(self):
Expand Down Expand Up @@ -810,8 +810,6 @@ def saveSample(self):
sampleDuration = parameterNode.GetParameter(self.SAMPLING_DURATION)
# Stop the timer
browserNode = parameterNode.GetNodeReference(self.SAMPLE_SEQ_BROWSER)
# browserNode.SetRecordingActive(False)
# print("Recording stopped")
# Get the sequence node
sequenceNode = parameterNode.GetNodeReference(self.SAMPLE_SEQUENCE)
# Save the sequence to a csv
Expand All @@ -824,45 +822,46 @@ def saveSample(self):
# Check to see if any data has been recorded
if sequenceLength == 0:
print("No data to save")
# Get the number of files in the folder
# numFiles = len([name for name in os.listdir(savePath) if os.path.isfile(os.path.join(savePath, name))])
# print("Number of files in folder: " + str(numFiles))
# Get the path to the current folder
currentPath = os.getcwd()
print(currentPath)
# add the data folder
dataPath = os.path.join(currentPath, "Data")
currentPath = os.join
return

# print("Sequence length: " + str(sequenceLength))
# Format the empty array
spectrumArray = slicer.util.arrayFromVolume(sequenceNode.GetNthDataNode(0)) # Get the length of a spectrum
SpectrumLength = spectrumArray.shape[2]
spectrumArray2D = np.zeros((sequenceLength + 1, SpectrumLength + 1)) # Create the 2D array
timeVector = np.linspace(0, float(sampleDuration), sequenceLength) # create a time vector using the sampleDuration
spectrumArray2D[1:,0] = timeVector # concatenate the time vector to the spectrum array
# print shape of the array
# print(spectrumArray2D.shape)
spectrumArray2D = np.zeros((sequenceLength + 1, SpectrumLength + 1)) # Create the 2D array, add 1 for wavelength and time
timeVector = np.linspace(0, float(sampleDuration), sequenceLength) # create a time vector using the sampleDuration
spectrumArray2D[1:,0] = timeVector # concatenate the time vector to the spectrum array
waveLengthVector = spectrumArray[0,0,:]
# print(waveLengthVector.shape)
spectrumArray2D[0,1:] = waveLengthVector
# print(spectrumArray2D)

for i in range(sequenceLength):
# Get a spectrum as an array
spectrumArray = np.squeeze(slicer.util.arrayFromVolume(sequenceNode.GetNthDataNode(i)))
spectrumArray2D[i+1,1:] = spectrumArray[1,:]
# Save the array to a csv
# np.savetxt(savePath + '.csv', spectrumArray2D[1:,:], delimiter=",")
numFiles = len([name for name in os.listdir(savePath) if os.path.isfile(os.path.join(savePath, name))]).zfill(3)
'''File naming convention: DataLabel_#ofFiles_TimeStamp.csv with TimeStamp in the format of YYYY-MM-DD_HH-MM-SS'''
# timeStamp = time.strftime("%Y-%m-%d_%H-%M-%S")
# fileName = dataLabel + '_' + str(numFiles) + '_' + timeStamp + '.csv'
fileName = dataLabel + '_' + str(numFiles) + '.csv'
np.savetxt(os.path.join(savePath, fileName), spectrumArray2D[1:,:], delimiter=",")
# If the save path does not exist, create it
if not os.path.exists(savePath):
os.makedirs(savePath)

'''File naming convention: TimeStamp_Patient#_#ofFiles_DataLabel.csv with TimeStamp in the format of MMMDD'''
# timestamp
timeStamp = time.strftime("%b%d")
FileNum = len([name for name in os.listdir(savePath) if os.path.isfile(os.path.join(savePath, name))]) + 1 # Get the file number
patientNum = parameterNode.GetParameter(self.PATIENT_NUM)
fileName = timeStamp + "_" + patientNum + "_" + str(FileNum).zfill(3) + "_" + dataLabel + ".csv"

# fileName = dataLabel + '_' + str(numFiles).zfill(3) + '.csv'
np.savetxt(os.path.join(savePath, fileName), spectrumArray2D[:,:], delimiter=",")
# print sample saved as well as name
print("Sample saved as: " + fileName)

#
# Processing functions
#


def onSpectrumImageNodeModified(self, observer, eventid):
parameterNode = self.getParameterNode()
spectrumImageNode = parameterNode.GetNodeReference(self.INPUT_VOLUME)
Expand Down
Loading

0 comments on commit 336f202

Please sign in to comment.