forked from torch/cunn
-
Notifications
You must be signed in to change notification settings - Fork 1
/
SpatialMaxPooling.cu
235 lines (203 loc) · 8.4 KB
/
SpatialMaxPooling.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
#include "utils.h"
#include "common.h"
// kernels borrowed from Caffe
template <typename Dtype>
__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w, Dtype* top_data,
Dtype* top_mask) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
int hend = min(hstart + kernel_h, height);
int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
Dtype maxval = -FLT_MAX;
int maxidx = -1;
bottom_data += (n * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (bottom_data[h * width + w] > maxval) {
maxidx = h * width + w;
maxval = bottom_data[maxidx];
}
}
}
top_data[index] = maxval;
top_mask[index] = maxidx + 1;
}
}
template <typename Dtype>
__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
const Dtype* top_mask, const int num, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
// find out the local offset
int w = index % width;
int h = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int phstart =
(h + pad_h < kernel_h) ? 0 : (h + pad_h - kernel_h) / stride_h + 1;
int phend = min((h + pad_h) / stride_h + 1, pooled_height);
int pwstart =
(w + pad_w < kernel_w) ? 0 : (w + pad_w - kernel_w) / stride_w + 1;
int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
Dtype gradient = 0;
int offset = (n * channels + c) * pooled_height * pooled_width;
top_diff += offset;
top_mask += offset;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if (top_mask[ph * pooled_width + pw] - 1 == h * width + w) {
gradient += top_diff[ph * pooled_width + pw];
}
}
}
bottom_diff[index] = gradient;
}
}
static int cunn_SpatialMaxPooling_updateOutput(lua_State *L)
{
THCState *state = getCutorchState(L);
THCudaTensor *input = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
int kW = luaT_getfieldcheckint(L, 1, "kW");
int kH = luaT_getfieldcheckint(L, 1, "kH");
int dW = luaT_getfieldcheckint(L, 1, "dW");
int dH = luaT_getfieldcheckint(L, 1, "dH");
int padW = luaT_getfieldcheckint(L, 1, "padW");
int padH = luaT_getfieldcheckint(L, 1, "padH");
bool ceil_mode = luaT_getfieldcheckboolean(L, 1, "ceil_mode");
THCudaTensor *output = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
THCudaTensor *indices = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "indices", "torch.CudaTensor");
THAssert(THCudaTensor_checkGPU(state, 3, input, output, indices));
luaL_argcheck(L, input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch) tensor expected");
long nInputCols, nInputRows, nInputPlane, batchSize;
long nOutputCols, nOutputRows;
if (input->nDimension == 3) {
nInputCols = input->size[2];
nInputRows = input->size[1];
nInputPlane = input->size[0];
batchSize = 1;
}
else
{
nInputCols = input->size[3];
nInputRows = input->size[2];
nInputPlane = input->size[1];
batchSize = input->size[0];
}
luaL_argcheck(L, nInputCols >= kW - padW && nInputRows >= kH - padH, 2, "input image smaller than kernel size");
luaL_argcheck(L, kW/2 >= padW && kH/2 >= padH, 2, "pad should be smaller than half of kernel size");
if(ceil_mode) {
nOutputCols = ceil(float(nInputCols - kW + 2*padW) / float(dW)) + 1;
nOutputRows = ceil(float(nInputRows - kH + 2*padH) / float(dH)) + 1;
}
else {
nOutputCols = floor(float(nInputCols - kW + 2*padW) / float(dW)) + 1;
nOutputRows = floor(float(nInputRows - kH + 2*padH) / float(dH)) + 1;
}
input = THCudaTensor_newContiguous(state, input);
float* input_data = THCudaTensor_data(state, input);
THCudaTensor_resize4d(state, output, batchSize, nInputPlane, nOutputRows, nOutputCols);
THCudaTensor_resizeAs(state, indices, output);
float* indices_data = THCudaTensor_data(state, indices);
float* output_data = THCudaTensor_data(state, output);
int count = THCudaTensor_nElement(state, output);
MaxPoolForward <<< GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>
(count, input_data,
batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
kH, kW, dH, dW, padH, padW, output_data, indices_data);
if(input->nDimension == 3)
THCudaTensor_resize3d(state, output, nInputPlane, nOutputRows, nOutputCols);
THCudaTensor_free(state, input);
// check for errors
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in SpatialMaxPooling.updateOutput: %s\n", cudaGetErrorString(err));
THError("aborting");
}
return 1;
}
static int cunn_SpatialMaxPooling_updateGradInput(lua_State *L)
{
THCState *state = getCutorchState(L);
THCudaTensor *input = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");
int kW = luaT_getfieldcheckint(L, 1, "kW");
int kH = luaT_getfieldcheckint(L, 1, "kH");
int dW = luaT_getfieldcheckint(L, 1, "dW");
int dH = luaT_getfieldcheckint(L, 1, "dH");
int padW = luaT_getfieldcheckint(L, 1, "padW");
int padH = luaT_getfieldcheckint(L, 1, "padH");
bool ceil_mode = luaT_getfieldcheckboolean(L, 1, "ceil_mode");
THCudaTensor *gradInput = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
THCudaTensor *indices = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "indices", "torch.CudaTensor");
THAssert(THCudaTensor_checkGPU(state, 4, input, gradOutput, indices, gradInput));
input = THCudaTensor_newContiguous(state, input);
gradOutput = THCudaTensor_newContiguous(state, gradOutput);
long nInputCols, nInputRows, nInputPlane, batchSize;
long nOutputCols, nOutputRows;
if (input->nDimension == 3) {
nInputCols = input->size[2];
nInputRows = input->size[1];
nInputPlane = input->size[0];
batchSize = 1;
}
else
{
nInputCols = input->size[3];
nInputRows = input->size[2];
nInputPlane = input->size[1];
batchSize = input->size[0];
}
if(ceil_mode) {
nOutputCols = ceil(float(nInputCols - kW + 2*padW) / float(dW)) + 1;
nOutputRows = ceil(float(nInputRows - kH + 2*padH) / float(dH)) + 1;
}
else {
nOutputCols = floor(float(nInputCols - kW + 2*padW) / float(dW)) + 1;
nOutputRows = floor(float(nInputRows - kH + 2*padH) / float(dH)) + 1;
}
THCudaTensor_resizeAs(state, gradInput, input);
int count = THCudaTensor_nElement(state, input);
MaxPoolBackward <<< GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>
(count,
THCudaTensor_data(state, gradOutput),
THCudaTensor_data(state, indices),
batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
kH, kW, dH, dW, padH, padW,
THCudaTensor_data(state, gradInput));
// check for errors
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in SpatialMaxPooling.updateGradInput: %s\n", cudaGetErrorString(err));
THError("aborting");
}
// clean
THCudaTensor_free(state, input);
THCudaTensor_free(state, gradOutput);
return 1;
}
static const struct luaL_Reg cunn_SpatialMaxPooling__ [] = {
{"SpatialMaxPooling_updateOutput", cunn_SpatialMaxPooling_updateOutput},
{"SpatialMaxPooling_updateGradInput", cunn_SpatialMaxPooling_updateGradInput},
{NULL, NULL}
};
void cunn_SpatialMaxPooling_init(lua_State *L)
{
luaT_pushmetatable(L, "torch.CudaTensor");
luaT_registeratname(L, cunn_SpatialMaxPooling__, "nn");
lua_pop(L,1);
}