-
Notifications
You must be signed in to change notification settings - Fork 2
/
fully_connected_layer.h
109 lines (87 loc) · 2.59 KB
/
fully_connected_layer.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#pragma once
#include "filler.h"
#include "layer.h"
#include "util.h"
namespace con {
class FullyConnectedLayer : public Layer {
public:
FullyConnectedLayer(
const string &name,
const int &depth,
Layer *prev,
Filler *weightFiller,
Filler *biasFiller) :
Layer(name, prev->num, 1, 1, depth, prev),
inputSize(prev->width * prev->height * prev->depth),
weight(depth * inputSize, weightFiller),
bias(depth, biasFiller) {
biasMultiplier = Vec(num, 1.0);
flatInput.resize(num * inputSize);
flatOutput.resize(num * depth);
flatNextErrors.resize(num * depth);
flatErrors.resize(num * inputSize);
}
const int inputSize;
Param weight;
Param bias;
Vec biasMultiplier;
Vec flatInput;
Vec flatOutput;
Vec flatNextErrors;
Vec flatErrors;
void flatten(const vector<Vec> a, Vec *b) {
int i = 0;
for (int j = 0; j < a.size(); j++) {
for (int k = 0; k < a[j].size(); k++) {
b->at(i++) = a[j][k];
}
}
}
void reconstruct(const Vec &a, vector<Vec> *b) {
int i = 0;
for (int j = 0; j < b->size(); j++) {
for (int k = 0; k < b->at(j).size(); k++) {
b->at(j)[k] = a[i++];
}
}
}
void forward() {
flatten(prev->output, &flatInput);
gemm(
CblasNoTrans, CblasTrans,
num, depth, inputSize,
1., flatInput, weight.value,
0., &flatOutput);
gemm(
CblasNoTrans, CblasNoTrans,
num, depth, 1,
1., biasMultiplier, bias.value,
1., &flatOutput);
reconstruct(flatOutput, &output);
}
void backProp(const vector<Vec> &nextErrors) {
clear(&errors);
flatten(nextErrors, &flatNextErrors);
gemm(
CblasTrans, CblasNoTrans,
depth, inputSize, num,
1., flatNextErrors, flatInput,
1., &weight.delta);
gemv(
CblasTrans,
num, depth,
1., flatNextErrors, biasMultiplier,
1., &bias.delta);
gemm(
CblasNoTrans, CblasNoTrans,
num, inputSize, depth,
1., flatNextErrors, weight.value,
0., &flatErrors);
reconstruct(flatErrors, &errors);
}
void applyUpdate(const Real &lr, const Real &momentum, const Real &decay) {
weight.update(lr, momentum, decay);
bias.update(lr, momentum, decay);
}
};
}