-
Notifications
You must be signed in to change notification settings - Fork 2
/
conv_layer.h
124 lines (100 loc) · 3.35 KB
/
conv_layer.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#pragma once
#include "filler.h"
#include "im2col.h"
#include "layer.h"
#include "param.h"
#include "util.h"
namespace con {
class ConvolutionalLayer : public Layer {
public:
ConvolutionalLayer(
const string &name,
const int &depth, const int &kernel, const int &stride, const int &padding,
Layer *prev,
Filler *weightFiller,
Filler *biasFiller) :
Layer(
name,
prev->num,
(prev->width - kernel + 2 * padding) / stride + 1,
(prev->height - kernel + 2 * padding) / stride + 1,
depth,
prev),
kernel(kernel), kernelArea(sqr(kernel)), stride(stride), padding(padding),
weight(kernelArea * inDepth * depth, weightFiller),
bias(depth * height * width, biasFiller) {
biasMultiplier = Vec(height * width, 1.0);
col.resize(width * height * inDepth * kernelArea);
}
const int kernel;
const int kernelArea;
const int stride;
const int padding;
Param weight;
Param bias;
// (1, width * height) ones matrix.
Vec biasMultiplier;
Vec col;
void forward() {
for (int n = 0; n < num; n++) {
forwardOnce(prev->output[n], &output[n]);
}
}
void forwardOnce(const Vec &input, Vec *output) {
im2col(input, inDepth, inHeight, inWidth, kernel, padding, stride, &col);
gemm(
CblasNoTrans, CblasNoTrans,
depth, width * height, kernelArea * inDepth,
1., weight.value, col,
0., output);
gemm(
CblasNoTrans, CblasNoTrans,
depth, width * height, 1,
1., bias.value, biasMultiplier,
1., output);
}
void backProp(const vector<Vec> &nextErrors) {
clear(&weight.delta);
clear(&bias.delta);
clear(&errors);
for (int n = 0; n < num; n++) {
backPropOnce(prev->output[n], nextErrors[n], &errors[n]);
}
}
void backPropOnce(const Vec &input, const Vec &nextErrors, Vec *errors) {
backPropBias(nextErrors, &bias.delta);
backPropInput(nextErrors, weight.value, errors);
backPropWeight(nextErrors, input, &weight.delta);
}
void backPropBias(const Vec &nextErrors, Vec *biasDelta) {
gemv(
CblasNoTrans,
depth, width * height,
1., nextErrors, biasMultiplier,
1., biasDelta);
}
void backPropInput(const Vec &nextErrors, const Vec &weight, Vec *errors) {
if (name == "conv1") {
return;
}
gemm(
CblasTrans, CblasNoTrans,
kernelArea * inDepth, width * height, depth,
1., weight, nextErrors,
0., &col);
col2im(col, inDepth, inHeight, inWidth, kernel, padding, stride, errors);
}
void backPropWeight(const Vec &nextErrors, const Vec &input, Vec *delta) {
im2col(input, inDepth, inHeight, inWidth, kernel, padding, stride, &col);
gemm(
CblasNoTrans, CblasTrans,
depth, kernelArea * inDepth, width * height,
1., nextErrors, col,
1., delta);
}
void applyUpdate(const Real &lr, const Real &momentum, const Real &decay) {
weight.update(lr, momentum, decay);
bias.update(lr, momentum, decay);
}
};
}