-
Notifications
You must be signed in to change notification settings - Fork 0
/
training.py
142 lines (127 loc) · 3.9 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
Use a ridge classifier with cross-validation to fit the state matrix (membrane potential) to correctly classify stimuli.
Compute the classification accuracy and the MSE.
"""
import sys
import time
import numpy as np
from helpers import train, reformat_df
from params import *
"""
parse the parameters and init. data structures
"""
runnum = 2  # number of repetitions (runs) that are loaded and evaluated
# TODO: runnum is currently 2; the original note says the default num. of repetition is 5

# Validate the argument count up front. The script accepts either the five
# basic arguments or all eight (basic + skip-connection) arguments; without
# this check a missing argument only surfaces as a bare IndexError.
if len(sys.argv) < 6 or 6 < len(sys.argv) < 9:
    raise SystemExit(
        "usage: {} network_mode delay_mode_intra delay_mode_inter "
        "delay_intra_param delay_inter_param "
        "[skip_double delay_skip_param skip_weights]".format(sys.argv[0])
    )

network_mode = sys.argv[1]  # input should be either "noise", "random" or "topo"
delay_mode_intra = sys.argv[2]
delay_mode_inter = sys.argv[3]
# NOTE(security): eval() executes arbitrary expressions taken from the command
# line. Only run this script with trusted arguments. If the parameters are
# always plain literals, ast.literal_eval would be a safer drop-in -- confirm
# against how the script is invoked before switching.
delay_intra_param = eval(sys.argv[4])
delay_inter_param = eval(sys.argv[5])

# skipping connections (optional second group of arguments)
if len(sys.argv) > 6:
    skip_double = bool(eval(sys.argv[6]))  # True if double connection is activated
    delay_skip_param = eval(sys.argv[7])  # increasing delays
    skip_weights = eval(sys.argv[8])  # might want to use decreasing weights, as a factor

# One row per run, one column per module; `module_depth` is not defined in this
# file and presumably comes from the star import of `params` -- confirm.
accuracy_train = np.zeros((runnum, module_depth))  # accuracy for each trial and each module within a trial
MSE_train = np.zeros((runnum, module_depth))  # averaged mean squared error
"""
load all raw data in appropriate formats
"""
# The part of the file name after "run=<index>_" depends only on the
# command-line parameters, so it is built once here instead of being
# re-formatted for every run and for both arrays.
npy_suffix = "{}_intra={}{}_inter={}{}".format(
    network_mode,
    delay_mode_intra,
    delay_intra_param,
    delay_mode_inter,
    delay_inter_param,
)
if len(sys.argv) > 6:
    # Skip-connection runs carry three extra fields in their file names.
    npy_suffix += "_skip_double={}_d={}_w={}".format(
        skip_double,
        delay_skip_param,
        skip_weights,
    )
npy_suffix += ".npy"

for runindex in range(runnum):
    # Load the recorded values and the matching stimuli of this run.
    volt_values = np.load(
        PATH + "voltvalues_run={}_{}".format(runindex, npy_suffix)
    )
    stimuli = np.load(PATH + "stimuli_run={}_{}".format(runindex, npy_suffix))
    """
    train the classifier and test it
    """
    # `train` is a project helper; its two return values are stored as this
    # run's row of the accuracy and MSE arrays.
    accuracy_train[runindex], MSE_train[runindex] = train(
        volt_values=volt_values, target_output=stimuli
    )
    # NOTE(review): the original indentation was lost, so it is unclear whether
    # this message was printed once per run or once after all runs -- confirm.
    print("training is done: ", time.process_time())
"""
save the summary data to the file
"""
# Turn both result arrays into data frames via the project helper.
df_acc = reformat_df(network_mode, accuracy_train)
df_mse = reformat_df(network_mode, MSE_train)

# Combine them into one frame: the MSE values become an extra column and the
# accuracy frame's "value" column is renamed to say what it holds.
df_acc["MSE"] = df_mse["value"]
df_acc.rename(columns={"value": "accuracy"}, inplace=True)

# Choose the output file name first, then write the frame once. The name has
# no directory prefix, so the CSV is created relative to the working directory.
base_fields = (
    network_mode,
    delay_mode_intra,
    delay_intra_param,
    delay_mode_inter,
    delay_inter_param,
)
if len(sys.argv) <= 6:
    csv_name = "training_{}_intra={}{}_inter={}{}.csv".format(*base_fields)
else:
    # Skip-connection runs add three extra fields to the file name.
    skip_fields = (skip_double, delay_skip_param, skip_weights)
    csv_name = "training_{}_intra={}{}_inter={}{}_skip_double={}_d={}_w={}.csv".format(
        *(base_fields + skip_fields)
    )
df_acc.to_csv(csv_name, index=False)