forked from kulinseth/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
AveragePool2d.cu
458 lines (411 loc) · 16.4 KB
/
AveragePool2d.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/native/Pool.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/macros/Macros.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/avg_pool2d_native.h>
#include <ATen/ops/avg_pool2d_backward_native.h>
#endif
namespace at::native {
namespace {
// Int-only minimum for device code. Because it is declared in this
// namespace, unqualified `min(...)` calls in the kernels below resolve here,
// so any int64_t argument is converted to int at the call.
__device__ inline int min(int a, int b) {
  return (b < a) ? b : a;
}
// Int-only maximum for device code; counterpart of the local `min` above it.
// Unqualified `max(...)` calls in the kernels below resolve to this overload.
__device__ inline int max(int a, int b) {
  return (b > a) ? b : a;
}
// Forward 2-D average pooling for NCHW (contiguous) tensors.
//
// Each iteration of CUDA_KERNEL_LOOP produces one output element `index` in
// [0, nthreads); the host launcher passes nthreads = output.numel().
// The flat index is decoded as (n, c, ph, pw) with pw varying fastest.
//
// Divisor selection:
//   * use_divisor        -> divisor_override
//   * count_include_pad  -> window area clipped to the *padded* input
//   * otherwise          -> window area clipped to the real input
// A window that lies entirely outside the real input writes 0.
template <typename scalar_t, typename accscalar_t>
__global__ void avg_pool2d_out_cuda_frame(const int nthreads,
    const scalar_t* const bottom_data, const int64_t channels,
    const int64_t height, const int64_t width, const int64_t pooled_height,
    const int pooled_width, const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
    scalar_t* const top_data, const int divisor_override,
    const bool count_include_pad, const bool use_divisor) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    // Peel off one output coordinate at a time.
    const int pw = index % pooled_width;
    const int row = index / pooled_width;      // flat (n, c, ph)
    const int ph = row % pooled_height;
    const int plane = row / pooled_height;     // flat (n, c)
    const int c = plane % channels;
    const int n = plane / channels;

    // Window bounds, first clipped only to the padded extent so that
    // `pool_size` counts padded cells as well.
    int y0 = ph * stride_h - pad_h;
    int x0 = pw * stride_w - pad_w;
    int y1 = min(y0 + kernel_h, height + pad_h);
    int x1 = min(x0 + kernel_w, width + pad_w);
    const int pool_size = (y1 - y0) * (x1 - x0);

    // Then clip to the real (unpadded) input.
    y0 = max(y0, 0);
    x0 = max(x0, 0);
    y1 = min(y1, height);
    x1 = min(x1, width);

    if (y0 < y1 && x0 < x1) {
      const scalar_t* const plane_in = bottom_data + (n * channels + c) * height * width;
      accscalar_t sum = accscalar_t(0);
      for (int y = y0; y < y1; ++y) {
        for (int x = x0; x < x1; ++x) {
          sum += plane_in[y * width + x];
        }
      }
      const int divide_factor = use_divisor
          ? divisor_override
          : (count_include_pad ? pool_size : (y1 - y0) * (x1 - x0));
      top_data[index] = static_cast<scalar_t>(sum / divide_factor);
    } else {
      top_data[index] = scalar_t(0);
    }
  }
}
// Forward 2-D average pooling for NHWC (channels-last) tensors.
//
// Each iteration of CUDA_KERNEL_LOOP produces one output element `index` in
// [0, nthreads); the host launcher passes nthreads = output.numel().
// The flat index is decoded as (n, ph, pw, c) with c varying fastest, and
// input pixels of one channel are `channels` elements apart.
//
// Divisor selection matches the NCHW kernel: divisor_override when
// use_divisor, else padded-window area (count_include_pad) or the area
// clipped to the real input. Windows entirely in the padding write 0.
template <typename scalar_t, typename accscalar_t>
__global__ void avg_pool2d_out_cuda_frame_nhwc(const int nthreads,
    const scalar_t* const bottom_data, const int64_t channels,
    const int64_t height, const int64_t width, const int pooled_height,
    const int pooled_width, const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
    scalar_t* const top_data, const int divisor_override,
    const bool count_include_pad, const bool use_divisor) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    // Peel off one output coordinate at a time, channel first.
    const int c = index % channels;
    const int pixel = index / channels;        // flat (n, ph, pw)
    const int pw = pixel % pooled_width;
    const int row = pixel / pooled_width;      // flat (n, ph)
    const int ph = row % pooled_height;
    const int n = row / pooled_height;

    // Window bounds clipped to the padded extent (for pool_size) ...
    int y0 = ph * stride_h - pad_h;
    int x0 = pw * stride_w - pad_w;
    int y1 = min(y0 + kernel_h, height + pad_h);
    int x1 = min(x0 + kernel_w, width + pad_w);
    const int pool_size = (y1 - y0) * (x1 - x0);

    // ... and then to the real input.
    y0 = max(y0, 0);
    x0 = max(x0, 0);
    y1 = min(y1, height);
    x1 = min(x1, width);

    if (y0 < y1 && x0 < x1) {
      const scalar_t* const image_in = bottom_data + n * channels * height * width + c;
      accscalar_t sum = accscalar_t(0);
      for (int y = y0; y < y1; ++y) {
        for (int x = x0; x < x1; ++x) {
          sum += image_in[(y * width + x) * channels];
        }
      }
      const int divide_factor = use_divisor
          ? divisor_override
          : (count_include_pad ? pool_size : (y1 - y0) * (x1 - x0));
      top_data[index] = static_cast<scalar_t>(sum / divide_factor);
    } else {
      top_data[index] = scalar_t(0);
    }
  }
}
// Backward 2-D average pooling for NCHW (contiguous) tensors.
//
// Each iteration of CUDA_KERNEL_LOOP computes the gradient of one input
// element `index` in [0, nthreads); the host launcher passes
// nthreads = input.numel(). The element's (h, w) are shifted by the padding
// so they live in the same frame as the pooling windows. Every output window
// that covers the element contributes top_diff / divisor, with the divisor
// chosen exactly as in the forward kernel.
template <typename scalar_t, typename accscalar_t>
__global__ void avg_pool2d_backward_out_cuda_frame(const int nthreads, const scalar_t* const top_diff,
    const int64_t channels, const int64_t height,
    const int64_t width, const int64_t pooled_height, const int64_t pooled_width,
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w,
    scalar_t* const bottom_diff, const int divisor_override,
    bool count_include_pad, bool use_divisor) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    // Decode the input element; h and w are in padded coordinates.
    const int row = index / width;             // flat (n, c, y)
    const int plane = row / height;            // flat (n, c)
    const int w = index % width + pad_w;
    const int h = row % height + pad_h;
    const int c = plane % channels;
    const int n = plane / channels;

    // Range of output rows/cols whose window contains (h, w).
    const int ph_first = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
    const int ph_last = min(h / stride_h + 1, pooled_height);
    const int pw_first = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
    const int pw_last = min(w / stride_w + 1, pooled_width);

    const scalar_t* const grad_plane =
        top_diff + (n * channels + c) * pooled_height * pooled_width;
    accscalar_t gradient = accscalar_t(0);
    for (int ph = ph_first; ph < ph_last; ++ph) {
      for (int pw = pw_first; pw < pw_last; ++pw) {
        // Recompute this window's divisor the same way the forward pass did.
        int y0 = ph * stride_h - pad_h;
        int x0 = pw * stride_w - pad_w;
        int y1 = min(y0 + kernel_h, height + pad_h);
        int x1 = min(x0 + kernel_w, width + pad_w);
        const int pool_size = (y1 - y0) * (x1 - x0);
        y0 = max(y0, 0);
        x0 = max(x0, 0);
        y1 = min(y1, height);
        x1 = min(x1, width);
        if (y0 < y1 && x0 < x1) {
          const int divide_factor = use_divisor
              ? divisor_override
              : (count_include_pad ? pool_size : (y1 - y0) * (x1 - x0));
          gradient += grad_plane[ph * pooled_width + pw] / divide_factor;
        }
      }
    }
    bottom_diff[index] = static_cast<scalar_t>(gradient);
  }
}
// Backward 2-D average pooling for NHWC (channels-last) tensors.
//
// Each iteration of CUDA_KERNEL_LOOP computes the gradient of one input
// element `index` in [0, nthreads); the host launcher passes
// nthreads = input.numel(). The flat index is decoded as (n, h, w, c) with c
// varying fastest. Every output window covering the element contributes
// top_diff / divisor, with the divisor chosen exactly as in the forward pass
// (divisor_override, padded area, or area clipped to the real input).
template <typename scalar_t, typename accscalar_t>
__global__ void avg_pool2d_backward_out_cuda_frame_nhwc(const int nthreads,
    const scalar_t* const top_diff,
    const int64_t channels, const int64_t height,
    const int64_t width, const int pooled_height, const int pooled_width,
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w,
    scalar_t* const bottom_diff, const int divisor_override,
    bool count_include_pad, bool use_divisor) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    // Decode the input element. h and w MUST be shifted by the padding: the
    // pooling windows below are defined in the padded frame
    // (start = p * stride - pad), exactly as in the NCHW backward kernel.
    // Without the shift, every element is matched to the wrong set of output
    // windows whenever pad > 0.
    const int c = index % channels;
    const int w = (index / channels) % width + pad_w;
    const int h = (index / channels / width) % height + pad_h;
    const int n = index / channels / width / height;

    // Range of output rows/cols whose window contains (h, w).
    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
    const int phend = min(h / stride_h + 1, pooled_height);
    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
    const int pwend = min(w / stride_w + 1, pooled_width);

    accscalar_t gradient = accscalar_t(0);
    // Channel c of image n; consecutive output pixels are `channels` apart.
    const scalar_t* const top_diff_slice = top_diff + n * channels * pooled_height * pooled_width + c;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        // Recompute this window's divisor the same way the forward pass did.
        int hstart = ph * stride_h - pad_h;
        int wstart = pw * stride_w - pad_w;
        int hend = min(hstart + kernel_h, height + pad_h);
        int wend = min(wstart + kernel_w, width + pad_w);
        int pool_size = (hend - hstart) * (wend - wstart);
        hstart = max(hstart, 0);
        wstart = max(wstart, 0);
        hend = min(hend, height);
        wend = min(wend, width);
        if (hstart >= hend || wstart >= wend) {
          continue;
        }
        int divide_factor;
        if (use_divisor) {
          divide_factor = divisor_override;
        } else {
          if (count_include_pad) {
            divide_factor = pool_size;
          } else {
            divide_factor = (hend - hstart) * (wend - wstart);
          }
        }
        gradient += top_diff_slice[(ph * pooled_width + pw) * channels] / divide_factor;
      }
    }
    bottom_diff[index] = static_cast<scalar_t>(gradient);
  }
}
} // anonymous namespace
TORCH_IMPL_FUNC(avg_pool2d_out_cuda)
(const Tensor& input_,
 int64_t kH_,
 int64_t kW_,
 int64_t dH_,
 int64_t dW_,
 int64_t padH_,
 int64_t padW_,
 bool ceil_mode,
 bool count_include_pad,
 c10::optional<int64_t> divisor_override,
 const Tensor& output) {
  // Forward average pooling over the last two dims of `input_`, written into
  // the pre-allocated `output`. Launches the NHWC or NCHW kernel depending on
  // the input's suggested memory format; other formats are rejected.
  TensorArg output_arg{ output, "output", 1 };
  TensorArg input_arg{ input_, "input_", 2 };
  checkAllSameGPU("avg_pool2d_out_cuda", {output_arg, input_arg});

  // Kernel geometry is passed to the device as int; safe_downcast checks fit.
  const int kH = safe_downcast<int, int64_t>(kH_);
  const int kW = safe_downcast<int, int64_t>(kW_);
  const int dH = safe_downcast<int, int64_t>(dH_);
  const int dW = safe_downcast<int, int64_t>(dW_);
  const int padH = safe_downcast<int, int64_t>(padH_);
  const int padW = safe_downcast<int, int64_t>(padW_);

  // Planes and spatial extent come from the trailing three dims.
  const int64_t nInputPlane = input_.size(-3);
  const int64_t inputHeight = input_.size(-2);
  const int64_t inputWidth = input_.size(-1);
  const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
  const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);

  const auto memory_format = input_.suggest_memory_format();
  Tensor input = input_.contiguous(memory_format);

  // One thread per output element; the element count must fit in int32.
  const int32_t count = safe_downcast<int32_t, int64_t>(output.numel());
  const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
  const uint32_t num_blocks = ceil_div<uint32_t>(count, num_threads);

  const bool use_divisor = divisor_override.has_value();
  const int64_t divisor_override_value = divisor_override.value_or(0);

  if (count == 0) {
    return;
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
    "avg_pool2d_out_cuda_frame",
    [&] {
      using accscalar_t = acc_type<scalar_t, true>;
      scalar_t* const output_data = output.data_ptr<scalar_t>();
      const scalar_t* const input_data = input.data_ptr<scalar_t>();

      if (memory_format == MemoryFormat::ChannelsLast) {
        output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
        avg_pool2d_out_cuda_frame_nhwc<scalar_t, accscalar_t>
            <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                count, input_data,
                nInputPlane, inputHeight, inputWidth,
                outputHeight, outputWidth,
                kH, kW, dH, dW, padH, padW,
                output_data, divisor_override_value,
                count_include_pad, use_divisor);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (memory_format == MemoryFormat::Contiguous) {
        avg_pool2d_out_cuda_frame<scalar_t, accscalar_t>
            <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                count, input_data,
                nInputPlane, inputHeight, inputWidth,
                outputHeight, outputWidth,
                kH, kW, dH, dW, padH, padW,
                output_data, divisor_override_value,
                count_include_pad, use_divisor);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
      }
    }
  );
}
TORCH_IMPL_FUNC(avg_pool2d_backward_out_cuda) (
  const Tensor& gradOutput_,
  const Tensor& input_,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  c10::optional<int64_t> divisor_override,
  const Tensor& gradInput
) {
  // Backward average pooling: fills `gradInput` (one value per input element)
  // from `gradOutput_`. Launches the NHWC or NCHW kernel depending on the
  // input's suggested memory format; other formats are rejected.
  TensorArg gradInput_arg{ gradInput, "gradInput", 1 };
  TensorArg gradOutput_arg{ gradOutput_, "gradOutput_", 2 };
  TensorArg input_arg{ input_, "input_", 3 };
  checkAllSameGPU("avg_pool2d_backward_out_cuda",
                  {gradInput_arg, gradOutput_arg, input_arg});

  // A single-entry kernel_size / stride / padding applies to both dims; an
  // empty stride defaults to the kernel size.
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() != 1 ? safe_downcast<int, int64_t>(kernel_size[1]) : kH;
  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  int dW = kW;
  if (!stride.empty()) {
    dW = stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);
  }
  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() != 1 ? safe_downcast<int, int64_t>(padding[1]) : padH;

  const auto memory_format = input_.suggest_memory_format();
  const Tensor input = input_.contiguous(memory_format);
  const Tensor gradOutput = gradOutput_.contiguous(memory_format);

  // Planes and spatial extent come from the trailing three dims.
  const int64_t nInputPlane = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);
  const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);
  const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);

  // One thread per input element; the element count must fit in int32.
  const int32_t count = safe_downcast<int32_t, int64_t>(input.numel());
  if (count == 0) {
    return;
  }

  const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
  const uint32_t num_blocks = ceil_div<uint32_t>(count, num_threads);
  const bool use_divisor = divisor_override.has_value();
  const int64_t divisor_override_value = divisor_override.value_or(0);

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
    "avg_pool2d_backward_out_cuda_frame",
    [&] {
      using accscalar_t = acc_type<scalar_t, true>;
      const scalar_t* const gradOutput_data = gradOutput.data_ptr<scalar_t>();
      scalar_t* const gradInput_data = gradInput.data_ptr<scalar_t>();

      if (memory_format == MemoryFormat::ChannelsLast) {
        gradInput.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
        avg_pool2d_backward_out_cuda_frame_nhwc<scalar_t, accscalar_t>
            <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                count, gradOutput_data,
                nInputPlane, inputHeight, inputWidth,
                outputHeight, outputWidth,
                kH, kW, dH, dW, padH, padW,
                gradInput_data, divisor_override_value,
                count_include_pad, use_divisor);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (memory_format == MemoryFormat::Contiguous) {
        avg_pool2d_backward_out_cuda_frame<scalar_t, accscalar_t>
            <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                count, gradOutput_data,
                nInputPlane, inputHeight, inputWidth,
                outputHeight, outputWidth,
                kH, kW, dH, dW, padH, padW,
                gradInput_data, divisor_override_value,
                count_include_pad, use_divisor);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
      }
    }
  );
}
} // namespace at::native