forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
AdaptiveMaxPooling2d.cu
478 lines (406 loc) · 13.8 KB
/
AdaptiveMaxPooling2d.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <c10/util/Exception.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/adaptive_max_pool2d_backward_native.h>
#include <ATen/ops/adaptive_max_pool2d_native.h>
#include <ATen/ops/empty.h>
#endif
#include <algorithm>
#include <cfloat>
#include <cmath>
namespace at::native {
namespace {
// First input index (inclusive) of the adaptive-pooling window for output
// index `a`, when `b` output positions are mapped onto `c` input positions:
// floor(a * c / b).
// The split form (a / b) * c + ((a % b) * c) / b equals floor(a * c / b) for
// non-negative operands and keeps the partial product (a % b) * c below b * c.
// Preconditions: a >= 0, b > 0, c >= 0.
// Callable from host and device so the exact same window math can be reused
// (and checked) outside kernels.
__host__ __device__ inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
  const int64_t whole = (a / b) * c;
  const int64_t partial = ((a % b) * c) / b;
  return whole + partial;
}
// One past the last input index (exclusive end) of the adaptive-pooling
// window for output index `a`, when `b` output positions are mapped onto `c`
// input positions: ceil((a + 1) * c / b), written with integer division as
// 1 + ((a + 1) * c - 1) / b.
// Preconditions: a >= 0, b > 0, c > 0 (with c == 0 the formula yields 1).
// Callable from host and device so the exact same window math can be reused
// (and checked) outside kernels.
__host__ __device__ inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
  const int64_t scaled_end = (a + 1) * c;
  return 1 + (scaled_end - 1) / b;
}
// 4d tensor B x D x H x W
/*
 * Description:
 *    this function adaptively maxpools an input 4D tensor along dimensions 2 and 3
 *    4D input, 4D output, 4D argmax x and y
 *
 *    Launch layout expected by the indexing below:
 *      - blockIdx.x selects one plane. Output/indices planes are dense
 *        (osizeH * osizeW elements each); the input plane starts at
 *        blockIdx.x * istrideD and is read through istrideH / istrideW.
 *      - threadIdx.x walks output columns with step blockDim.x.
 *      - blockIdx.y * blockDim.y + threadIdx.y walks output rows with step
 *        blockDim.y * gridDim.y.
 *    Output cell (oh, ow) is the max over input rows
 *    [start_index(oh, osizeH, isizeH), end_index(oh, osizeH, isizeH)) and
 *    input columns [start_index(ow, osizeW, isizeW), end_index(ow, osizeW, isizeW)).
 *    `indices` receives the flat in-plane position (h * isizeW + w) of the
 *    selected element. A NaN always replaces the running max, so NaNs
 *    propagate to the output.
 */
template <typename T>
__global__ void adaptivemaxpool(const T *input, T *output, int64_t *indices,
                        int isizeH, int isizeW,
                        int osizeH, int osizeW,
                        int64_t istrideD, int64_t istrideH, int64_t istrideW)
{
  // compute offsets based on thread/block ID
  const int64_t o_plane = blockIdx.x;
  const int64_t i_plane = o_plane;

  const int ostartW = threadIdx.x;
  const int oendW = osizeW;
  const int ostepW = blockDim.x;

  const int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
  const int oendH = osizeH;
  const int ostepH = blockDim.y*gridDim.y;

  // select input/output plane. Offsets are computed in 64-bit: with many
  // planes, plane * osizeH * osizeW exceeds the range of a 32-bit int.
  const int64_t oplane_numel = static_cast<int64_t>(osizeH) * osizeW;
  output = output + o_plane*oplane_numel;
  input = input + i_plane*istrideD;
  indices = indices + o_plane*oplane_numel;

  // For all output pixels...
  for (int oh = ostartH; oh < oendH; oh += ostepH) {
    const int istartH = start_index(oh, osizeH, isizeH);
    const int iendH = end_index(oh, osizeH, isizeH);
    const int kH = iendH - istartH;

    for (int ow = ostartW; ow < oendW; ow += ostepW) {
      const int istartW = start_index(ow, osizeW, isizeW);
      const int iendW = end_index(ow, osizeW, isizeW);
      const int kW = iendW - istartW;

      // Compute the max of the input window...
      const T *ptr_input = input + istartH*istrideH + istartW*istrideW;
      const int64_t o_offset = static_cast<int64_t>(oh)*osizeW + ow;
      T *ptr_output = output + o_offset;
      int64_t *ptr_ind = indices + o_offset;
      // 64-bit so the flat in-plane position cannot overflow before it is
      // stored into the int64_t indices tensor.
      int64_t argmax = static_cast<int64_t>(istartH) * isizeW + istartW;
      T max = at::numeric_limits<T>::lower_bound(); // -Infinity
      for (int ih = 0; ih < kH; ih++) {
        for (int iw = 0; iw < kW; iw++) {
          T val = ptr_input[iw*istrideW];
          if ((val > max) || at::_isnan(val)) {
            max = val;
            argmax = static_cast<int64_t>(ih+istartH)*isizeW + iw+istartW;
          }
        }
        ptr_input += istrideH; // next input line
      }
      // Update output and argmax
      *ptr_output = max;
      *ptr_ind = argmax;
    }
  }
}
/*
 * Description:
 *    this function computes the gradInput from indices and gradOutput:
 *    every gradOutput element of a plane is added to gradInput at the flat
 *    in-plane position (h * isizeW + w) stored in `indices`.
 *
 *    The accumulation is a plain, non-atomic `+=`. It is race-free only when
 *    no two output cells of the same plane share an argmax; otherwise use
 *    atomicadaptivemaxgradinput.
 *
 *    Launch layout expected by the indexing below:
 *      - blockIdx.x selects one plane. gradInput planes are dense
 *        (isizeH * isizeW elements); gradOutput/indices planes are dense
 *        (osizeH * osizeW elements).
 *      - threadIdx.x walks output columns with step blockDim.x.
 *      - blockIdx.y * blockDim.y + threadIdx.y walks output rows with step
 *        blockDim.y * gridDim.y.
 */
template <typename T>
__global__ void adaptivemaxgradinput(T *gradInput, const T *gradOutput, const int64_t *indices,
                             int isizeH, int isizeW,
                             int osizeH, int osizeW)
{
  // compute offsets based on thread/block ID
  const int64_t o_plane = blockIdx.x;
  const int64_t i_plane = o_plane;

  const int ostartW = threadIdx.x;
  const int oendW = osizeW;
  const int ostepW = blockDim.x;

  const int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
  const int oendH = osizeH;
  const int ostepH = blockDim.y*gridDim.y;

  // select input/output plane. Offsets are computed in 64-bit: with many
  // planes, plane * (plane size) exceeds the range of a 32-bit int.
  const int64_t oplane_numel = static_cast<int64_t>(osizeH) * osizeW;
  const int64_t iplane_numel = static_cast<int64_t>(isizeH) * isizeW;
  gradOutput = gradOutput + o_plane*oplane_numel;
  gradInput = gradInput + i_plane*iplane_numel;
  indices = indices + o_plane*oplane_numel;

  // compute gradInput
  for (int oh = ostartH; oh < oendH; oh += ostepH) {
    for (int ow = ostartW; ow < oendW; ow += ostepW) {
      const int64_t o_offset = static_cast<int64_t>(oh)*osizeW + ow;
      const T z = gradOutput[o_offset];
      // keep the stored index at full width instead of truncating to int
      const int64_t argmax = indices[o_offset];
      gradInput[argmax] += z;
    }
  }
}
/*
 * Description:
 *    this function computes the gradInput from indices and gradOutput:
 *    every gradOutput element of a plane is added to gradInput at the flat
 *    in-plane position (h * isizeW + w) stored in `indices`.
 *    Uses gpuAtomicAddNoReturn, so it stays correct when several output cells
 *    of a plane select the same input element (overlapping windows).
 *
 *    Launch layout expected by the indexing below:
 *      - blockIdx.x selects one plane. gradInput planes are dense
 *        (isizeH * isizeW elements); gradOutput/indices planes are dense
 *        (osizeH * osizeW elements).
 *      - threadIdx.x walks output columns with step blockDim.x.
 *      - blockIdx.y * blockDim.y + threadIdx.y walks output rows with step
 *        blockDim.y * gridDim.y.
 */
template <typename T>
__global__ void atomicadaptivemaxgradinput(
  T *gradInput, const T *gradOutput, const int64_t *indices,
  int isizeH, int isizeW, int osizeH, int osizeW
)
{
  // compute offsets based on thread/block ID
  const int64_t o_plane = blockIdx.x;
  const int64_t i_plane = o_plane;

  const int ostartW = threadIdx.x;
  const int oendW = osizeW;
  const int ostepW = blockDim.x;

  const int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
  const int oendH = osizeH;
  const int ostepH = blockDim.y*gridDim.y;

  // select input/output plane. Offsets are computed in 64-bit: with many
  // planes, plane * (plane size) exceeds the range of a 32-bit int.
  const int64_t oplane_numel = static_cast<int64_t>(osizeH) * osizeW;
  const int64_t iplane_numel = static_cast<int64_t>(isizeH) * isizeW;
  gradOutput = gradOutput + o_plane*oplane_numel;
  gradInput = gradInput + i_plane*iplane_numel;
  indices = indices + o_plane*oplane_numel;

  // compute gradInput
  for (int oh = ostartH; oh < oendH; oh += ostepH) {
    for (int ow = ostartW; ow < oendW; ow += ostepW) {
      const int64_t o_offset = static_cast<int64_t>(oh)*osizeW + ow;
      const T z = gradOutput[o_offset];
      // keep the stored index at full width instead of truncating to int
      const int64_t argmax = indices[o_offset];
      // atomic add since different threads could update same variable
      gpuAtomicAddNoReturn(&(gradInput[argmax]), z);
    }
  }
}
} // namespace
// 4d tensor B x D x H x W
// Forward adaptive max pooling over the last two dimensions of a 3-D
// (D x H x W) or 4-D (B x D x H x W) input: writes pooled values to `output`
// and the flat in-plane argmax positions to `indices`.
TORCH_IMPL_FUNC(adaptive_max_pool2d_out_cuda)
(const Tensor& input,
 IntArrayRef output_size,
 const Tensor& output,
 const Tensor& indices) {
  TensorArg output_arg{output, "output", 1};
  TensorArg indices_arg{indices, "indices", 2};
  TensorArg input_arg{input, "input", 3};
  checkAllSameGPU(__func__, {output_arg, indices_arg, input_arg});

  if (input.numel() == 0) {
    return;
  }

  const int64_t osizeH = output_size[0];
  const int64_t osizeW = output_size[1];

  // The kernel writes densely, so results are staged in contiguous buffers
  // whenever a destination is not contiguous, and copied back at the end.
  const bool write_output_directly = output.is_contiguous();
  const bool write_indices_directly = indices.is_contiguous();
  const at::Tensor output_c = write_output_directly
      ? output
      : at::empty(output.sizes(), output.options());
  const at::Tensor indices_c = write_indices_directly
      ? indices
      : at::empty(indices.sizes(), indices.options());

  // A 3-D input is read in place through its own strides and behaves like a
  // batch of one; a 4-D input is made contiguous first.
  const bool batched = input.ndimension() != 3;
  const Tensor input_ = batched ? input.contiguous() : input;
  const int64_t sizeB = batched ? input_.size(0) : 1;
  const int64_t sizeD = input_.size(-3);
  const int64_t isizeH = input_.size(-2);
  const int64_t isizeW = input_.size(-1);
  // For 4-D input the kernel treats the batch and channel dimensions as if
  // they are flattened and uses istrideD as the stride of this flattened dim.
  // H * W is used instead of stride(1) to handle the edge case where
  // input_.size(1) == 1, where despite passing the contiguity check the
  // stride might not be H * W.
  const int64_t istrideD = batched ? isizeH * isizeW : input_.stride(-3);
  const int64_t istrideH = input_.stride(-2);
  const int64_t istrideW = input_.stride(-1);

  // cuda blocks & threads: grid x enumerates the sizeB * sizeD planes
  const int blocksH = std::max(1, static_cast<int>(16L / sizeD));
  const dim3 blocks(sizeB * sizeD, blocksH);
  const dim3 threads(32, 8);

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, input_.scalar_type(), "adaptive_max_pool2d_cuda", [&] {
        const scalar_t* in_ptr = input_.const_data_ptr<scalar_t>();
        scalar_t* out_ptr = output_c.mutable_data_ptr<scalar_t>();
        int64_t* idx_ptr = indices_c.mutable_data_ptr<int64_t>();
        // run maxpool kernel
        adaptivemaxpool<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
            in_ptr, out_ptr, idx_ptr,
            isizeH, isizeW,
            osizeH, osizeW,
            istrideD, istrideH, istrideW);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });

  if (!write_output_directly) {
    output.copy_(output_c);
  }
  if (!write_indices_directly) {
    indices.copy_(indices_c);
  }
}
// Backward of adaptive max pooling: scatters each gradOutput element into
// gradInput at the argmax position recorded in `indices` for a 3-D
// (D x H x W) or 4-D (B x D x H x W) input.
TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
(const Tensor& gradOutput,
 const Tensor& input,
 const Tensor& indices,
 const Tensor& gradInput) {
  globalContext().alertNotDeterministic(
      "adaptive_max_pool2d_backward_cuda");
  TensorArg grad_input_arg{gradInput, "gradInput", 1};
  TensorArg grad_output_arg{gradOutput, "gradOutput", 2};
  TensorArg input_arg{input, "input", 3};
  TensorArg indices_arg{indices, "indices", 4};
  checkAllSameGPU(
      __func__,
      {grad_input_arg, grad_output_arg, input_arg, indices_arg});

  if (gradOutput.numel() == 0) {
    // No output cell routes any gradient back, so the input gradient is all
    // zeros. Nothing else writes gradInput on this path, so zero it here
    // rather than leaving its previous contents in place.
    gradInput.zero_();
    return;
  }

  // suboptimal, but without atomic it doesn't pass the tests
  // (candidate heuristic: atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0))
  const bool atomic = true;

  const at::Tensor gradOutput_ = gradOutput.contiguous();
  const at::Tensor indices_ = indices.contiguous();
  // The kernels write densely; stage into a contiguous buffer when needed.
  const at::Tensor gradInput_c = gradInput.is_contiguous() ? gradInput : at::empty(gradInput.sizes(), gradInput.options());

  // A 3-D input behaves like a batch of one; either way the kernels see
  // sizeB * sizeD independent planes.
  const bool batched = input.ndimension() != 3;
  const int64_t sizeB = batched ? input.size(0) : 1;
  const int64_t sizeD = input.size(-3);
  const int64_t isizeH = input.size(-2);
  const int64_t isizeW = input.size(-1);
  const int64_t osizeH = gradOutput_.size(-2);
  const int64_t osizeW = gradOutput_.size(-1);

  // Gradients are accumulated into gradInput_c, so it must start at zero.
  gradInput_c.zero_();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      input.scalar_type(),
      "adaptive_max_pool2d_backward_cuda",
      [&] {
        scalar_t* gradInput_data = gradInput_c.mutable_data_ptr<scalar_t>();
        const scalar_t* gradOutput_data = gradOutput_.const_data_ptr<scalar_t>();
        const int64_t* indices_data = indices_.const_data_ptr<int64_t>();
        // cuda blocks & threads: grid x enumerates the sizeB * sizeD planes
        int blocksH = (int)(16L / sizeD);
        blocksH = blocksH < 1 ? 1 : blocksH;
        dim3 blocks(sizeB * sizeD, blocksH);
        dim3 threads(32, 8);
        if (atomic) {
          // run updateGradInput kernel, accumulate gradients atomically
          atomicadaptivemaxgradinput<<<
              blocks,
              threads,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
              gradInput_data,
              gradOutput_data,
              indices_data,
              isizeH,
              isizeW,
              osizeH,
              osizeW);
        } else {
          // run updateGradInput kernel with plain (non-atomic) accumulation;
          // only valid when no two output cells of a plane share an argmax
          adaptivemaxgradinput<<<
              blocks,
              threads,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
              gradInput_data,
              gradOutput_data,
              indices_data,
              isizeH,
              isizeW,
              osizeH,
              osizeW);
        }
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });

  if (!gradInput.is_contiguous()) {
    gradInput.copy_(gradInput_c);
  }
}
} // namespace at::native