-
Notifications
You must be signed in to change notification settings - Fork 12
/
elm_train.m
79 lines (61 loc) · 1.92 KB
/
elm_train.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
function model = elm_train(X_train, Y_train, option)
%ELM_TRAIN Train a single-hidden-layer Extreme Learning Machine (ELM).
%   MODEL = ELM_TRAIN(X_TRAIN, Y_TRAIN, OPTION) builds the hidden-layer
%   output matrix H with ELM_HIDDEN_LAYER_GEN, then solves a
%   ridge-regularised least-squares problem for the output weights.
%
%   Inputs
%     X_train - training inputs, passed unchanged to elm_hidden_layer_gen
%               (presumably one sample per row - confirm with that helper).
%     Y_train - training targets; converted to double. For 'classifier' the
%               class of a row is taken as the column index of its maximum.
%     option  - struct with fields
%                 n_hidden_nodes : number of hidden neurons
%                 act_func       : activation, forwarded to elm_hidden_layer_gen
%                 c_rho          : log2 of the regularisation constant C (C = 2^c_rho)
%                 elm_type       : 'classifier' or 'regression'
%                 seed           : (optional) RNG seed; default derived from cputime
%
%   Output
%     model - struct with fields elm_type, InputWeight, n_hidden_nodes, N
%             (= n_hidden_nodes), BiasHidden, OutputWeight, act_func,
%             c_rho (the exponent as given, not 2^c_rho), TrainTime (seconds
%             spent building H and solving for OutputWeight) and TrainEVAL
%             (training accuracy for 'classifier', training RMSE for
%             'regression').
%
%   Errors with identifier 'elm_train:unknownType' if option.elm_type is
%   not 'classifier' or 'regression'.

%% Read options
n_hidden_nodes = option.n_hidden_nodes;
act_func = option.act_func;
c_rho = 2^option.c_rho;          % regularisation constant C; option.c_rho is its log2
elm_type = option.elm_type;
% Validate elm_type before the expensive training step; otherwise an
% unsupported value only surfaces at the end as an undefined TrainEVAL.
if ~any(strcmp(elm_type, {'classifier', 'regression'}))
    error('elm_train:unknownType', ...
        'option.elm_type must be ''classifier'' or ''regression''.');
end
if isfield(option, 'seed')
    seed = option.seed;
else
    % Fallback seed derived from CPU time (only 100 distinct values).
    seed = fix(mod(cputime, 100));
end
% Legacy seeding syntax, kept so random streams match earlier runs. It
% switches MATLAB's global generator to the legacy 'seed' generator;
% rng(seed) is the modern form but would produce different numbers.
rand('seed', seed);

%% Load training targets
% Work in double so integer/single targets do not truncate the arithmetic.
T = double(Y_train);
t1 = tic;

%% Hidden layer: input weights, biases and hidden-layer output H
% elm_hidden_layer_gen is defined elsewhere; H is expected to have one row
% per training sample, since H*OutputWeight is compared row-wise with T.
[H, InputWeight, BiasHidden] = elm_hidden_layer_gen(X_train, n_hidden_nodes, act_func, seed);

%% Output weights (beta): ridge-regularised least squares
% Solve whichever linear system is smaller:
%   more rows than columns : beta = (H'*H + I/C) \ (H'*T)
%   otherwise              : beta = H' * ((H*H' + I/C) \ T)
% The identity is sized from H itself rather than from n_hidden_nodes.
[n_rows, n_cols] = size(H);
if n_rows > n_cols
    OutputWeight = (H' * H + eye(n_cols) / c_rho) \ (H' * T);
else
    OutputWeight = H' * ((H * H' + eye(n_rows) / c_rho) \ T);
end
TrainTime = toc(t1);

%% Training performance
pred = H * OutputWeight;
if strcmp(elm_type, 'classifier')
    % Fraction of rows whose arg-max predicted column equals the arg-max
    % target column (NaN when there are no training rows, as before).
    [~, predicted_label] = max(pred, [], 2);
    [~, actual_label] = max(T, [], 2);
    TrainEVAL = mean(predicted_label == actual_label);
else % 'regression'
    % Root-mean-square error over all target entries, computed directly so
    % that no toolbox function (mse) is required.
    residual = T - pred;
    TrainEVAL = sqrt(mean(residual(:) .^ 2));
end

%% Assemble model
model.elm_type = elm_type;
model.InputWeight = InputWeight;
model.n_hidden_nodes = n_hidden_nodes;
model.N = n_hidden_nodes;
model.BiasHidden = BiasHidden;
model.OutputWeight = OutputWeight;
model.act_func = act_func;
model.c_rho = option.c_rho;      % store the exponent, not 2^c_rho
model.TrainTime = TrainTime;
model.TrainEVAL = TrainEVAL;