forked from torch/cunn
-
Notifications
You must be signed in to change notification settings - Fork 1
/
SpatialConvolutionMM.cu
462 lines (399 loc) · 17.5 KB
/
SpatialConvolutionMM.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
#include "utils.h"
#include "common.h"
// Kernel for fast unfold+copy (im2col).
// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu)
//
// Each loop iteration handles one (input channel, output row, output column)
// triple, so n is expected to be channels * height_col * width_col (see the
// im2col launcher). For that triple it writes the ksize_h * ksize_w patch
// values into data_col, which is laid out as (channels * ksize_h * ksize_w)
// rows of height_col * width_col entries. Taps that fall outside the
// image (i.e. into the zero padding) are written as 0.
__global__ void im2col_kernel(const int n, const float* data_im,
    const int height, const int width, const int ksize_h, const int ksize_w, const int pad_h,
    const int pad_w, const int stride_h, const int stride_w, const int height_col, const int width_col,
    float* data_col) {
  CUDA_KERNEL_LOOP(index, n) {
    // Decode the flat index into new locals instead of dividing `index` in
    // place. NOTE(review): CUDA_KERNEL_LOOP comes from common.h and is
    // presumably Caffe's grid-stride loop, which advances `index` after each
    // iteration -- clobbering it would corrupt every iteration after the first.
    const int w_out = index % width_col;
    const int h_index = index / width_col;
    const int h_out = h_index % height_col;
    const int channel_in = h_index / height_col;
    const int channel_out = channel_in * ksize_h * ksize_w;
    // Top-left corner of the patch in (unpadded) image coordinates; may be
    // negative when the patch overlaps the padding.
    const int h_in = h_out * stride_h - pad_h;
    const int w_in = w_out * stride_w - pad_w;
    // Per-iteration cursor / offset: the kernel arguments stay untouched so a
    // later iteration of the loop starts from the correct base pointers.
    // The image offset is kept as an integer so no out-of-range pointer is
    // ever formed for padded taps; it is only dereferenced when in bounds.
    float* col_ptr = data_col + (channel_out * height_col + h_out) * width_col + w_out;
    const int im_base = (channel_in * height + h_in) * width + w_in;
    for (int i = 0; i < ksize_h; ++i) {
      for (int j = 0; j < ksize_w; ++j) {
        const int h = h_in + i;
        const int w = w_in + j;
        *col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ?
            data_im[im_base + i * width + j] : 0.0f;
        // Consecutive patch taps live in consecutive data_col rows.
        col_ptr += height_col * width_col;
      }
    }
  }
}
// Host launcher for im2col_kernel.
// Unfolds every (ksize_h x ksize_w) patch of a (channels x height x width)
// image into data_col. data_col must hold
// (channels * ksize_h * ksize_w) x (height_col * width_col) floats.
// The launch is only enqueued on `stream`; nothing is synchronized here.
void im2col(cudaStream_t stream, const float* data_im, const int channels,
    const int height, const int width, const int ksize_h, const int ksize_w, const int pad_h,
    const int pad_w, const int stride_h, const int stride_w, float* data_col) {
  // Size of the output grid: one column of data_col per output location.
  const int out_w = (width + 2 * pad_w - ksize_w) / stride_w + 1;
  const int out_h = (height + 2 * pad_h - ksize_h) / stride_h + 1;
  // One work item per (channel, output row, output column) triple.
  const int work_items = channels * out_h * out_w;
  im2col_kernel<<<GET_BLOCKS(work_items), CUDA_NUM_THREADS, 0, stream>>>(
      work_items, data_im,
      height, width,
      ksize_h, ksize_w,
      pad_h, pad_w,
      stride_h, stride_w,
      out_h, out_w,
      data_col);
}
// Kernel for col2im: the adjoint of im2col_kernel.
// Each loop iteration owns exactly one image element (n is expected to be
// channels * height * width, see the col2im launcher). It sums every
// data_col entry that im2col would have copied from that element and
// stores the sum. Because each element is written by a single iteration,
// no atomics are needed.
//
// Reference formulation, of which the loop below is an algebraically
// equivalent strength-reduced form (h, w are padded coordinates):
//   for h_col in [h_col_start, h_col_end):
//     for w_col in [w_col_start, w_col_end):
//       c_col = c * patch_h * patch_w
//             + (h - h_col * stride_h) * patch_w + (w - w_col * stride_w);
//       val += data_col[(c_col * height_col + h_col) * width_col + w_col];
__global__ void col2im_kernel(const int n, const float* data_col,
    const int height, const int width, const int channels, const int patch_h, const int patch_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int height_col, const int width_col,
    float* data_im) {
  CUDA_KERNEL_LOOP(index, n) {
    // Position of this element in the *padded* image, plus its channel.
    const int x_pad = index % width + pad_w;
    const int y_pad = (index / width) % height + pad_h;
    const int chan = index / (width * height);

    // Half-open range of output positions whose patch covers (y_pad, x_pad).
    const int col_x_begin = (x_pad < patch_w) ? 0 : (x_pad - patch_w) / stride_w + 1;
    const int col_x_end = min(x_pad / stride_w + 1, width_col);
    const int col_y_begin = (y_pad < patch_h) ? 0 : (y_pad - patch_h) / stride_h + 1;
    const int col_y_end = min(y_pad / stride_h + 1, height_col);

    // Linearised reference index: base + cy * step_y + cx * step_x.
    const int plane = height_col * width_col;
    const int base = (chan * patch_h * patch_w + y_pad * patch_w + x_pad) * plane;
    const int step_y = (1 - stride_h * patch_w * height_col) * width_col;
    const int step_x = 1 - stride_w * plane;

    // Same accumulation order as the reference (rows outer, columns inner).
    float acc = 0;
    for (int cy = col_y_begin; cy < col_y_end; ++cy) {
      for (int cx = col_x_begin; cx < col_x_end; ++cx) {
        acc += data_col[base + cy * step_y + cx * step_x];
      }
    }
    data_im[index] = acc;
  }
}
// Host launcher for col2im_kernel.
// Folds a columns buffer laid out as im2col produces it, for the same
// geometry, back into a (channels x height x width) image. Each element of
// data_im is overwritten with the sum of the column entries that map onto it.
// The launch is only enqueued on `stream`; nothing is synchronized here.
void col2im(cudaStream_t stream, const float* data_col, const int channels,
    const int height, const int width, const int patch_h, const int patch_w, const int pad_h,
    const int pad_w, const int stride_h, const int stride_w, float* data_im) {
  // Output-grid size, identical to the one used by im2col.
  const int out_w = (width + 2 * pad_w - patch_w) / stride_w + 1;
  const int out_h = (height + 2 * pad_h - patch_h) / stride_h + 1;
  // One work item per image element, so the kernel never needs atomics.
  const int work_items = channels * height * width;
  col2im_kernel<<<GET_BLOCKS(work_items), CUDA_NUM_THREADS, 0, stream>>>(
      work_items, data_col,
      height, width, channels,
      patch_h, patch_w,
      pad_h, pad_w,
      stride_h, stride_w,
      out_h, out_w,
      data_im);
}
/*
 * Lua binding: SpatialConvolutionMM forward pass.
 *
 * Stack: 1 = module table (fields dW, dH, kW, kH, nInputPlane, nOutputPlane,
 * padW, padH, weight, bias, finput, fgradInput, output), 2 = input
 * CudaTensor, either 3D (nInputPlane x H x W) or 4D (batch x nInputPlane x H x W).
 *
 * For every sample: output = bias broadcast over all output pixels (via the
 * `ones` buffer) + weight * im2col(input). `finput` is reused as the im2col
 * buffer and `fgradInput` as the ones buffer.
 * NOTE(review): weight is presumably 2D (nOutputPlane x nInputPlane*kH*kW);
 * the GEMM below reads weight->size[0] and weight->size[1] that way.
 *
 * Returns 1 Lua value; per the original comment this is `output`, which
 * relies on `output` being the last field fetched onto the stack.
 */
static int cunn_SpatialConvolutionMM_updateOutput(lua_State *L) {
  THCState *state = getCutorchState(L);
  // Input
  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
  // Params:
  int dW = luaT_getfieldcheckint(L, 1, "dW");
  int dH = luaT_getfieldcheckint(L, 1, "dH");
  int kW = luaT_getfieldcheckint(L, 1, "kW");
  int kH = luaT_getfieldcheckint(L, 1, "kH");
  int nInputPlane = luaT_getfieldcheckint(L, 1, "nInputPlane");
  int nOutputPlane = luaT_getfieldcheckint(L, 1, "nOutputPlane");
  int padW = luaT_getfieldcheckint(L, 1, "padW");
  int padH = luaT_getfieldcheckint(L, 1, "padH");
  THCudaTensor *weight = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "weight", "torch.CudaTensor");
  THCudaTensor *bias = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "bias", "torch.CudaTensor");
  THCudaTensor *columns = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "finput", "torch.CudaTensor");
  THCudaTensor *ones = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "fgradInput", "torch.CudaTensor");
  // Keep `output` as the last field fetched: `return 1` below relies on it.
  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
  THAssert(THCudaTensor_checkGPU(state, 6, input, output, weight,
                                 bias, columns, ones));

  // Validate everything before touching `input`, so that a failed check can
  // neither leave the caller's tensor reshaped nor leak a reference.
  // The strides are used as divisors below (host-side division by zero otherwise).
  luaL_argcheck(L, dW > 0 && dH > 0, 1, "stride should be greater than zero");
  luaL_argcheck(L, input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch mode) tensor is expected");
  // Index of the plane dimension: 0 for a single 3D sample, 1 for a 4D batch.
  const int dimf = (input->nDimension == 4) ? 1 : 0;
  luaL_argcheck(L, input->size[dimf] == nInputPlane, 2, "input channels and nInputPlane dont match");

  long inputHeight = input->size[dimf + 1];
  long inputWidth = input->size[dimf + 2];
  long outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
  long outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
  if (outputWidth < 1 || outputHeight < 1)
    // Plane counts are int, spatial sizes are long: match the format specifiers.
    THError("Given input size: (%dx%ldx%ld). Calculated output size: (%dx%ldx%ld). Output size is too small",
            nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, outputWidth);

  // im2col and the GEMMs read raw data pointers, so work on a contiguous
  // tensor. NOTE(review): THC's newContiguous presumably retains `input`
  // itself when it is already contiguous and clones it otherwise; either way
  // the reference is released with THCudaTensor_free at the end.
  input = THCudaTensor_newContiguous(state, input);

  int batch = 1;
  if (input->nDimension == 3) {
    // Force batch: view the single sample as a batch of one.
    batch = 0;
    THCudaTensor_resize4d(state, input, 1, input->size[0], input->size[1], input->size[2]);
  }

  // Batch size + input planes
  long batchSize = input->size[0];

  // Resize output
  THCudaTensor_resize4d(state, output, batchSize, nOutputPlane, outputHeight, outputWidth);

  // Resize temporary columns
  THCudaTensor_resize2d(state, columns, nInputPlane*kW*kH, outputHeight*outputWidth);

  // Define a buffer of ones, for bias accumulation.
  // Note: this buffer can be shared with other modules, it only ever gets increased,
  // and always contains ones.
  if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
    // Resize plane and fill with ones...
    THCudaTensor_resize2d(state, ones, outputHeight, outputWidth);
    THCudaTensor_fill(state, ones, 1);
  }

  // Helpers
  THCudaTensor *input_n = THCudaTensor_new(state);
  THCudaTensor *output_n = THCudaTensor_new(state);

  // For each elt in batch, do:
  for (int elt = 0; elt < batchSize; elt ++) {
    // Matrix multiply per output:
    THCudaTensor_select(state, input_n, input, 0, elt);
    THCudaTensor_select(state, output_n, output, 0, elt);

    // Do Bias first. In row-major terms:
    //   output_n[nOutputPlane x oH*oW] = bias[nOutputPlane x 1] * ones[1 x oH*oW]
    // M,N,K are dims of matrix A and B
    // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
    long m_ = nOutputPlane;
    long n_ = outputHeight * outputWidth;
    long k_ = 1;

    // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
    THCudaBlas_gemm(
        state,
        't', 'n',
        n_, m_, k_,
        1,
        THCudaTensor_data(state, ones), k_,
        THCudaTensor_data(state, bias), k_,
        0,
        THCudaTensor_data(state, output_n), n_
    );

    // Extract columns:
    im2col(
      THCState_getCurrentStream(state),
      THCudaTensor_data(state, input_n),
      nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
      THCudaTensor_data(state, columns)
    );

    // Accumulate the convolution onto the bias. In row-major terms:
    //   output_n[m x n] += weight[m x k] * columns[k x n]
    // M,N,K are dims of matrix A and B
    // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
    long m = weight->size[0];
    long n = columns->size[1];
    long k = weight->size[1];

    // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
    THCudaBlas_gemm(
        state,
        'n', 'n',
        n, m, k,
        1,
        THCudaTensor_data(state, columns), n,
        THCudaTensor_data(state, weight), k,
        1,
        THCudaTensor_data(state, output_n), n
    );
  }

  // Free
  THCudaTensor_free(state, input_n);
  THCudaTensor_free(state, output_n);

  // Undo the forced batch dimension (restores the caller's input shape when
  // newContiguous returned the caller's own tensor).
  if (batch == 0) {
    THCudaTensor_resize3d(state, output, nOutputPlane, outputHeight, outputWidth);
    THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth);
  }

  // Release the reference obtained from newContiguous.
  THCudaTensor_free(state, input);

  // return output
  return 1;
}
/*
 * Lua binding: SpatialConvolutionMM gradient w.r.t. the input.
 *
 * Stack: 1 = module table (fields dW, dH, kW, kH, nInputPlane, nOutputPlane,
 * padW, padH, weight, finput, gradInput), 2 = input (3D or 4D CudaTensor;
 * only its shape is used), 3 = gradOutput, shaped like the forward output.
 *
 * For every sample: gradColumns = weight^T * gradOutput_n, then col2im folds
 * gradColumns back into gradInput_n. `finput` is reused as the gradColumns
 * buffer.
 *
 * Returns 1 Lua value; per the original comment this is `gradInput`, which
 * relies on `gradInput` being the last field fetched onto the stack.
 */
static int cunn_SpatialConvolutionMM_updateGradInput(lua_State *L) {
  THCState *state = getCutorchState(L);
  // Inputs
  THCudaTensor *input = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
  THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");

  // Params
  int dW = luaT_getfieldcheckint(L, 1, "dW");
  int dH = luaT_getfieldcheckint(L, 1, "dH");
  int kW = luaT_getfieldcheckint(L, 1, "kW");
  int kH = luaT_getfieldcheckint(L, 1, "kH");
  int nInputPlane = luaT_getfieldcheckint(L, 1, "nInputPlane");
  int nOutputPlane = luaT_getfieldcheckint(L, 1, "nOutputPlane");
  int padW = luaT_getfieldcheckint(L, 1, "padW");
  int padH = luaT_getfieldcheckint(L, 1, "padH");

  THCudaTensor *weight = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "weight", "torch.CudaTensor");
  THCudaTensor *gradColumns = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "finput", "torch.CudaTensor");
  // Keep `gradInput` as the last field fetched: `return 1` below relies on it.
  THCudaTensor *gradInput = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
  THAssert(THCudaTensor_checkGPU(state, 5, input, gradOutput, weight,
                                 gradColumns, gradInput));

  // Validate everything before reshaping anything in place, so a failed
  // check cannot leave the caller's tensors reshaped or leak a reference.
  // The strides are used as divisors below (host-side division by zero otherwise).
  luaL_argcheck(L, dW > 0 && dH > 0, 1, "stride should be greater than zero");
  luaL_argcheck(L, input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch mode) tensor is expected");
  // Index of the plane dimension: 0 for a single 3D sample, 1 for a 4D batch.
  const int dimf = (input->nDimension == 4) ? 1 : 0;
  luaL_argcheck(L, input->size[dimf] == nInputPlane, 2, "input channels and nInputPlane dont match");

  long inputHeight = input->size[dimf + 1];
  long inputWidth = input->size[dimf + 2];
  long outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
  long outputHeight = (inputHeight + 2*padH - kH) / dH + 1;

  // The GEMM reads nOutputPlane * outputHeight * outputWidth values per
  // sample from gradOutput, so a mismatched gradOutput would be read out of
  // bounds. The nDimension test short-circuits before any size[] access.
  luaL_argcheck(L, gradOutput->nDimension == input->nDimension &&
                   (dimf == 0 || gradOutput->size[0] == input->size[0]) &&
                   gradOutput->size[dimf] == nOutputPlane &&
                   gradOutput->size[dimf + 1] == outputHeight &&
                   gradOutput->size[dimf + 2] == outputWidth,
                3, "gradOutput does not match the expected output size");

  // The GEMM reads gradOutput through a raw pointer, so make it contiguous.
  // NOTE(review): THC's newContiguous presumably retains the tensor itself
  // when already contiguous and clones it otherwise; released at the end.
  gradOutput = THCudaTensor_newContiguous(state, gradOutput);

  int batch = 1;
  if (input->nDimension == 3) {
    // Force batch: view the single sample as a batch of one.
    batch = 0;
    THCudaTensor_resize4d(state, input, 1, input->size[0], input->size[1], input->size[2]);
    THCudaTensor_resize4d(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2]);
  }

  // Batch size + input planes
  long batchSize = input->size[0];

  // Resize output
  THCudaTensor_resize4d(state, gradInput, batchSize, nInputPlane, inputHeight, inputWidth);

  // Resize temporary columns
  THCudaTensor_resize2d(state, gradColumns, nInputPlane*kW*kH, outputHeight*outputWidth);

  // Helpers
  THCudaTensor *gradInput_n = THCudaTensor_new(state);
  THCudaTensor *gradOutput_n = THCudaTensor_new(state);

  // For each elt in batch, do:
  for (int elt = 0; elt < batchSize; elt ++) {
    // Matrix multiply per sample:
    THCudaTensor_select(state, gradInput_n, gradInput, 0, elt);
    THCudaTensor_select(state, gradOutput_n, gradOutput, 0, elt);

    // In row-major terms: gradColumns[m x n] = weight^T[m x k] * gradOutput_n[k x n]
    // M,N,K are dims of matrix A and B
    // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
    long m = weight->size[1];
    long n = gradColumns->size[1];
    long k = weight->size[0];

    // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
    THCudaBlas_gemm(
        state,
        'n', 't',
        n, m, k,
        1,
        THCudaTensor_data(state, gradOutput_n), n,
        THCudaTensor_data(state, weight), m,
        0,
        THCudaTensor_data(state, gradColumns), n
    );

    // Unpack columns back into input:
    col2im(
      THCState_getCurrentStream(state),
      THCudaTensor_data(state, gradColumns),
      nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
      THCudaTensor_data(state, gradInput_n)
    );
  }

  // Free
  THCudaTensor_free(state, gradInput_n);
  THCudaTensor_free(state, gradOutput_n);

  // Undo the forced batch dimension.
  if (batch == 0) {
    THCudaTensor_resize3d(state, gradOutput, nOutputPlane, outputHeight, outputWidth);
    THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth);
    THCudaTensor_resize3d(state, gradInput, nInputPlane, inputHeight, inputWidth);
  }

  // Release the reference obtained from newContiguous.
  THCudaTensor_free(state, gradOutput);

  // Return gradInput
  return 1;
}
/*
 * Lua binding: SpatialConvolutionMM parameter-gradient accumulation.
 *
 * Stack: 1 = module table (fields dW, dH, kW, kH, nInputPlane, nOutputPlane,
 * padW, padH, gradWeight, gradBias, finput, fgradInput), 2 = input (3D or 4D
 * CudaTensor), 3 = gradOutput, shaped like the forward output,
 * 4 = optional scale (default 1).
 *
 * For every sample, in row-major terms:
 *   gradWeight += scale * gradOutput_n * im2col(input_n)^T
 *   gradBias   += scale * (row sums of gradOutput_n, via the ones buffer)
 * `finput` is reused as the im2col buffer and `fgradInput` as the ones buffer.
 *
 * Returns 0 Lua values.
 */
static int cunn_SpatialConvolutionMM_accGradParameters(lua_State *L) {
  THCState *state = getCutorchState(L);
  // Inputs
  THCudaTensor *input = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
  THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");

  // Params
  int dW = luaT_getfieldcheckint(L, 1, "dW");
  int dH = luaT_getfieldcheckint(L, 1, "dH");
  int kW = luaT_getfieldcheckint(L, 1, "kW");
  int kH = luaT_getfieldcheckint(L, 1, "kH");
  int nInputPlane = luaT_getfieldcheckint(L, 1, "nInputPlane");
  int nOutputPlane = luaT_getfieldcheckint(L, 1, "nOutputPlane");
  int padW = luaT_getfieldcheckint(L, 1, "padW");
  int padH = luaT_getfieldcheckint(L, 1, "padH");
  float scale = luaL_optnumber(L, 4, 1);

  THCudaTensor *gradWeight = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "gradWeight", "torch.CudaTensor");
  THCudaTensor *gradBias = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "gradBias", "torch.CudaTensor");
  THCudaTensor *columns = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "finput", "torch.CudaTensor");
  THCudaTensor *ones = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "fgradInput", "torch.CudaTensor");
  THAssert(THCudaTensor_checkGPU(state, 6, input, gradOutput, gradWeight,
                                 gradBias, columns, ones));

  // Validate everything before reshaping anything in place, so a failed
  // check cannot leave the caller's tensors reshaped or leak a reference.
  // The strides are used as divisors below (host-side division by zero otherwise).
  luaL_argcheck(L, dW > 0 && dH > 0, 1, "stride should be greater than zero");
  luaL_argcheck(L, input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch mode) tensor is expected");
  // Index of the plane dimension: 0 for a single 3D sample, 1 for a 4D batch.
  const int dimf = (input->nDimension == 4) ? 1 : 0;
  luaL_argcheck(L, input->size[dimf] == nInputPlane, 2, "input channels and nInputPlane dont match");

  long inputHeight = input->size[dimf + 1];
  long inputWidth = input->size[dimf + 2];
  long outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
  long outputHeight = (inputHeight + 2*padH - kH) / dH + 1;

  // The GEMM/GEMV read nOutputPlane * outputHeight * outputWidth values per
  // sample from gradOutput, so a mismatched gradOutput would be read out of
  // bounds. The nDimension test short-circuits before any size[] access.
  luaL_argcheck(L, gradOutput->nDimension == input->nDimension &&
                   (dimf == 0 || gradOutput->size[0] == input->size[0]) &&
                   gradOutput->size[dimf] == nOutputPlane &&
                   gradOutput->size[dimf + 1] == outputHeight &&
                   gradOutput->size[dimf + 2] == outputWidth,
                3, "gradOutput does not match the expected output size");

  // im2col and the BLAS calls read raw data pointers, so work on contiguous
  // tensors. NOTE(review): THC's newContiguous presumably retains the tensor
  // itself when already contiguous and clones it otherwise; both references
  // are released at the end.
  input = THCudaTensor_newContiguous(state, input);
  gradOutput = THCudaTensor_newContiguous(state, gradOutput);

  int batch = 1;
  if (input->nDimension == 3) {
    // Force batch: view the single sample as a batch of one.
    batch = 0;
    THCudaTensor_resize4d(state, input, 1, input->size[0], input->size[1], input->size[2]);
    THCudaTensor_resize4d(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2]);
  }

  // Batch size + input planes
  long batchSize = input->size[0];

  // Define a buffer of ones, for bias accumulation
  if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
    // Resize plane and fill with ones...
    THCudaTensor_resize2d(state, ones, outputHeight, outputWidth);
    THCudaTensor_fill(state, ones, 1);
  }

  // Resize temporary columns
  THCudaTensor_resize2d(state, columns, nInputPlane*kW*kH, outputHeight*outputWidth);

  // Helpers
  THCudaTensor *input_n = THCudaTensor_new(state);
  THCudaTensor *gradOutput_n = THCudaTensor_new(state);

  // For each elt in batch, do:
  for (int elt = 0; elt < batchSize; elt ++) {
    // Matrix multiply per output:
    THCudaTensor_select(state, input_n, input, 0, elt);
    THCudaTensor_select(state, gradOutput_n, gradOutput, 0, elt);

    // Extract columns:
    im2col(
      THCState_getCurrentStream(state),
      THCudaTensor_data(state, input_n),
      nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
      THCudaTensor_data(state, columns)
    );

    // In row-major terms: gradWeight[m x n] += scale * gradOutput_n[m x k] * columns^T[k x n]
    // M,N,K are dims of matrix A and B
    // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
    long m = gradWeight->size[0];
    long n = gradWeight->size[1];
    long k = columns->size[1];

    // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
    THCudaBlas_gemm(
        state,
        't', 'n',
        n, m, k,
        scale,
        THCudaTensor_data(state, columns), k,
        THCudaTensor_data(state, gradOutput_n), k,
        1,
        THCudaTensor_data(state, gradWeight), n
    );

    // Do Bias: gradBias[m_] += scale * gradOutput_n[m_ x k_] * ones[k_]
    // M,N,K are dims of matrix A and B
    // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
    long m_ = nOutputPlane;
    long k_ = outputHeight * outputWidth;

    // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices)
    THCudaBlas_gemv(
        state,
        't',
        k_, m_,
        scale,
        THCudaTensor_data(state, gradOutput_n), k_,
        THCudaTensor_data(state, ones), 1,
        1,
        THCudaTensor_data(state, gradBias), 1
    );
  }

  // Free
  THCudaTensor_free(state, input_n);
  THCudaTensor_free(state, gradOutput_n);

  // Undo the forced batch dimension.
  if (batch == 0) {
    THCudaTensor_resize3d(state, gradOutput, nOutputPlane, outputHeight, outputWidth);
    THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth);
  }

  // Release the references obtained from newContiguous.
  THCudaTensor_free(state, input);
  THCudaTensor_free(state, gradOutput);

  // Return nothing
  return 0;
}
// Lua-visible entry points of this file, registered by
// cunn_SpatialConvolutionMM_init. The {NULL, NULL} entry terminates the list.
static const struct luaL_Reg cunn_SpatialConvolutionMM__ [] = {
  { "SpatialConvolutionMM_updateOutput",
    cunn_SpatialConvolutionMM_updateOutput },
  { "SpatialConvolutionMM_updateGradInput",
    cunn_SpatialConvolutionMM_updateGradInput },
  { "SpatialConvolutionMM_accGradParameters",
    cunn_SpatialConvolutionMM_accGradParameters },
  { NULL, NULL }
};
// Registers the SpatialConvolutionMM_* bindings under the "nn" name of the
// torch.CudaTensor metatable, then pops the metatable pushed here.
void cunn_SpatialConvolutionMM_init(lua_State *L) {
  luaT_pushmetatable(L, "torch.CudaTensor");
  luaT_registeratname(L, cunn_SpatialConvolutionMM__, "nn");
  // Drop the metatable again; lua_settop(L, -2) is what lua_pop(L, 1) expands to.
  lua_settop(L, -2);
}