Skip to content

Commit

Permalink
updated NeuralNetworks
Browse files Browse the repository at this point in the history
  • Loading branch information
catubc committed Mar 23, 2020
1 parent e1a8de3 commit 78495c6
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 49 deletions.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
86 changes: 37 additions & 49 deletions src/yass/yass_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class plot_widget:
def __init__(self, window):
self.window = window

# voltage window
# voltage window
self.fig2 = Figure(figsize=(6,3))
self.a2 = self.fig2.add_subplot(111)
self.a2.set_yticks([])
Expand All @@ -28,7 +28,7 @@ def __init__(self, window):
self.canvas2.get_tk_widget().place(x = 225, y = 300)#, relwidth = , relheight = 1)
self.canvas2.draw()

# geometry window
# geometry window
self.fig1 = Figure(figsize=(3.0,3))
self.a = self.fig1.add_subplot(111)
self.a.set_ylabel("um", fontsize=8)
Expand Down Expand Up @@ -69,6 +69,7 @@ def set_filename_denoise(self):
self.nn_denoise_box.delete(0, 'end')
self.nn_denoise_box.insert(0, self.nn_denoise_txt)


def refresh(self):

# yaml can't save tuples, so we save them as strings
Expand All @@ -83,26 +84,28 @@ def refresh(self):

self.config_params['neuralnetwork']['denoise']['filter_sizes'] = \
str(self.config_params['neuralnetwork']['denoise']['filter_sizes'])

# update entries from screen


# update single entries from screen
self.root_dir = self.root_dir_box.get()
#self.config_params['data']['root_folder']+self.config_params['data']['recordings']
#self.data_root + self.config_params['data']['recordings']

self.sample_rate = float(self.sample_rate_box.get())
self.config_params['recordings']['sampling_rate'] = self.sample_rate
self.config_params['recordings']['sampling_rate'] = int(self.sample_rate)

self.n_chan = int(self.n_chan_box.get())
self.config_params['recordings']['n_channels'] = self.n_chan

self.radius = float(self.radius_box.get())
self.config_params['recordings']['spatial_radius'] = self.radius

# redraw everything
# plot geometry file
self.spike_size_ms = float(self.spike_size_ms_box.get())
self.config_params['recordings']['spike_size_ms'] = self.spike_size_ms

# redraw everything
# plot geometry file
self.window.filemenu.plot.plot_geom()

# load snipit of data and visualize
# load snipit of data and visualize
self.window.filemenu.plot.plot_voltage()

with open(self.fname_config[:-5]+"_modified.yaml", 'w') as f:
Expand All @@ -118,6 +121,7 @@ def refresh(self):
# Replace the target string
filedata = filedata.replace("'[", "[")
filedata = filedata.replace("]'", "]")
filedata = filedata.replace("null", "")

# Write the file out again
with open(self.fname_config[:-5]+"_modified.yaml", 'w') as file:
Expand Down Expand Up @@ -160,12 +164,7 @@ def refresh_button(self, txt, x, y):
#self.filename_loaded = txt
button_refresh = Button(self.window,text=txt,command=lambda:self.refresh())
button_refresh.place(x=x, y=y)

# reread all text data into

# reload everything
# self.load_config()



def display_train_run_button(self, label_txt, txt, x, y):
#self.filename_loaded = txt
Expand All @@ -176,9 +175,14 @@ def display_train_run_button(self, label_txt, txt, x, y):
dirname=Entry(self.window)
dirname.insert(0,txt)
dirname.place(x=x+95,y=y)


def run(self):

# first save the file
self.refresh()

# run yass
cmd = "yass sort "+ self.fname_config[:-5]+"_modified.yaml"

returned_value = os.system(cmd) # returns the exit code in unix
Expand Down Expand Up @@ -213,11 +217,11 @@ def display_metadata_and_buttons(self):

if self.config_params['data']['root_folder']=='./':
self.data_root = os.path.split(self.fname_config)[0]+'/'
self.config_params['data']['root_folder'] = self.data_root
else:
self.data_root = self.config_params['data']['root_folder']

#print ("SELF data root: ", self.data_root)


# display root directory
#text_ = self.config_params['data']['root_folder']+self.config_params['data']['recordings']
text_ = self.data_root+self.config_params['data']['recordings']
Expand All @@ -238,19 +242,25 @@ def display_metadata_and_buttons(self):
text_= self.config_params['recordings']['spatial_radius']
self.radius_box = self.display_single("Neighbour dist ", text_, 70, 120)
self.radius = float(text_)

# load radius of local chans:
text_= self.config_params['recordings']['spike_size_ms']
self.spike_size_ms_box = self.display_single("Spk width ", text_, 70, 150)
self.spike_size_ms = float(text_)


# Refresh button
#text_ = config_params['neuralnetwork']['detect']['filename']
self.refresh_button("Update config", 500, 30)

# display NN detect
text_ = self.config_params['neuralnetwork']['detect']['filename']
self.nn_detect_button("NN detect", text_, 0, 160)
self.nn_detect_button("NN detect", text_, 0, 180)
self.nn_detect = text_

# display NN denoise
text_ = self.config_params['neuralnetwork']['denoise']['filename']
self.nn_denoise_button("NN detect", text_, 0, 190)
self.nn_denoise_button("NN detect", text_, 0, 210)
self.nn_denoise = text_

# run YASS button
Expand Down Expand Up @@ -282,30 +292,27 @@ def plot_geom(self):

self.a.set_title("Geometry (partial)\n(centre chan + neigbhours)", fontsize=8)

# load geometry file
# load geometry file
geom_file = self.data_root + self.config_params['data']['geometry']
print (geom_file)
#print (geom_file)
geom = np.loadtxt(geom_file)
#print ("Geom: ", geom)

self.a.scatter(geom[:,0],geom[:,1],s=10, color='black')

#find nearest chan to middle:
#find nearest chan to middle:
middle_chan_x = np.mean(geom[:,0])
middle_chan_y = np.mean(geom[:,1])
mid_chan_id = self.closest_node([middle_chan_x, middle_chan_y], geom)

self.a.scatter(geom[mid_chan_id,0],geom[mid_chan_id,1],s=150,color='red')

# find chans within radius
# find chans within radius
local_chans_id = self.chans_within_radius(geom[mid_chan_id], geom, self.radius)

self.a.scatter(geom[local_chans_id,0],geom[local_chans_id,1],s=50, color='blue')

# zoom on on arrays to show middle chan + neighbour chans as example
#print (np.min(geom[local_chans_id,0]), np.min(geom[local_chans_id,1]))
#print (np.max(geom[local_chans_id,0]), np.max(geom[local_chans_id,1]))

# zoom on on arrays to show middle chan + neighbour chans as example
spacer = 60
if np.min(geom[local_chans_id,0])>0:
min_ = np.min(geom[local_chans_id,0])-spacer
Expand Down Expand Up @@ -375,7 +382,7 @@ def plot_voltage(self):
start_ch=0

x = np.arange(end-start)/float(self.sampling_rate)
print (x.shape, end-start)
#print (x.shape, end-start)
for c in range(start_ch,min(start_ch+10, start_ch+self.n_channels)):
self.a2.plot(x, rawdata2D[start:end,c]+c*100,c='black')

Expand All @@ -394,7 +401,7 @@ def parse_yaml(self, filename):
# scalar values to Python the dictionary format
self.config_params = yaml.load(f, Loader=yaml.FullLoader)

print (self.config_params)
#print (self.config_params)
# laod meta data
self.window.filemenu.plot.display_metadata_and_buttons()

Expand All @@ -405,22 +412,3 @@ def parse_yaml(self, filename):
self.window.filemenu.plot.plot_voltage()


# root = Tk()
# root.title('YASS')
# root.geometry("800x600") #You want the size of the app to be 500x500
# root.resizable(0, 0)

# # initialize plotting widget
# plot = plot_widget(root)

# # initialize menu widget
# menubar = Menu(root)

# # add menu items
# root.filemenu = Menu(menubar, tearoff=0)
# root.filemenu.plot = plot
# root.filemenu.add_command(label="Open", command=plot.load_config)
# menubar.add_cascade(label="File", menu=root.filemenu)

# root.config(menu=menubar)
# root.mainloop()

0 comments on commit 78495c6

Please sign in to comment.