forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
SpatialDepthwiseConvolution.cu
268 lines (236 loc) · 10.6 KB
/
SpatialDepthwiseConvolution.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
// updateOutput, updateGradInput Kernels ported from Sergey Zagoruyko's pyinn, which itself was a
// port from Caffe
#include <THCUNN/THCUNN.h>
#include <THC/THCTensor.hpp>
#include <THC/THCDeviceTensor.cuh>
#include <THC/THCDeviceTensorUtils.cuh>
#include <THC/THCNumerics.cuh>
#include <THC/THCReduceApplyUtils.cuh>
#include <THC/THCSortUtils.cuh>
#include <THC/THCTensorMathReduce.cuh>
#include <THCUNN/SharedMem.cuh>
#include <THCUNN/common.h>
#include <algorithm>
#include <c10/macros/Macros.h>
// Cap on the thread count returned by getGradParamsNumThreads below.
// Crude benchmarks suggest 256 is better than 512 and 1024.
// TODO: Autotune/use better heuristics, improve speed more.
const int MAX_BLOCK_SIZE = 256;

// Thread count for a parameter-gradient launch: one warp (C10_WARP_SIZE
// threads) per batch item, but never more than MAX_BLOCK_SIZE threads.
// Presumably used as blockDim.x for spatialDepthwiseConvolutionAccGradParameters,
// whose launch site is not visible here.
static int getGradParamsNumThreads(int batchSize) {
  const int oneWarpPerItem = batchSize * C10_WARP_SIZE;
  if (oneWarpPerItem < MAX_BLOCK_SIZE) {
    return oneWarpPerItem;
  }
  return MAX_BLOCK_SIZE;
}
// Forward kernel of the depthwise convolution.
//
// Each loop iteration produces one output element (n, c, h, w):
//   bias[c] (if biasEnabled)
//   + sum_{kH, kW} weight[c][kH][kW] * input[n][c / depthwiseMultiplier][h_in][w_in]
// where h_in/w_in come from stride, padding and dilation. Taps that land in
// the padding (outside the input) are skipped.
//
// Launch: 1-D grid and blocks of any size. The loop advances by
// gridDim.x * blockDim.x, so every linear index in [0, totalElements) is
// visited exactly once regardless of grid size.
// kSize != 0 makes the window size a compile-time constant (kSize x kSize) so
// the loops can be unrolled. kSize == 0 uses the runtime kernelWidth/kernelHeight.
// NOTE(review): the weight offset always uses the runtime kernelWidth and
// kernelHeight. The caller presumably only picks kSize != 0 when
// kernelWidth == kernelHeight == kSize — confirm.
// All tensors are addressed through .data() with flat NCHW offsets, i.e. they
// are assumed to be contiguous.
// Index/offset arithmetic is carried out in IndexType. This lets a 64-bit
// IndexType address tensors with more than 2^31 elements without
// intermediate int overflow.
template <typename T, typename AccT, typename IndexType, int kSize>
__global__ void spatialDepthwiseConvolutionUpdateOutput(
    const THCDeviceTensor<T, 4> input,
    THCDeviceTensor<T, 4> output,
    const THCDeviceTensor<T, 4> weight,
    const THCDeviceTensor<T, 1> bias,
    bool biasEnabled,
    IndexType totalElements,
    const int outputChannels,
    const int depthwiseMultiplier,
    const int inputWidth, const int inputHeight,
    const int outputWidth, const int outputHeight,
    const int kernelWidth, const int kernelHeight,
    const int strideWidth, const int strideHeight,
    const int padWidth, const int padHeight,
    const int dilationWidth, const int dilationHeight)
{
  const int KW_LIMIT = (kSize != 0) ? kSize : kernelWidth;
  const int KH_LIMIT = (kSize != 0) ? kSize : kernelHeight;

  // Each input channel feeds depthwiseMultiplier consecutive output channels.
  // This is loop-invariant, so compute it once per thread.
  const int inputChannels = outputChannels / depthwiseMultiplier;

  const IndexType gridStride = static_cast<IndexType>(gridDim.x) * blockDim.x;
  for (IndexType linearIndex = static_cast<IndexType>(blockIdx.x) * blockDim.x + threadIdx.x;
       linearIndex < totalElements;
       linearIndex += gridStride) {
    // Decompose linearIndex into (n, c, h, w). Divide + multiply-subtract is
    // used instead of modulo; the result equals:
    //   w = linearIndex % outputWidth
    //   h = (linearIndex / outputWidth) % outputHeight
    //   c = (linearIndex / (outputWidth * outputHeight)) % outputChannels
    //   n = linearIndex / (outputWidth * outputHeight * outputChannels)
    // The quotients stay in IndexType so they are never truncated to int.
    const IndexType nch = linearIndex / outputWidth;
    const int w = static_cast<int>(linearIndex - nch * outputWidth);
    const IndexType nc = nch / outputHeight;
    const int h = static_cast<int>(nch - nc * outputHeight);
    const IndexType n = nc / outputChannels;
    const int c = static_cast<int>(nc - n * outputChannels);

    const int inputChannel = c / depthwiseMultiplier;
    // Filters are stored contiguously, one kernelHeight x kernelWidth slab per
    // output channel. weightOffset walks that slab in (kH, kW) order.
    int weightOffset = c * kernelHeight * kernelWidth;
    AccT value = biasEnabled ? ScalarConvert<T, AccT>::to(bias.data()[c]) : ScalarConvert<int, AccT>::to(0);
    // Start of the (n, inputChannel) input plane. n is IndexType, so the whole
    // product is evaluated in IndexType.
    const IndexType offset0 = (n * inputChannels + inputChannel) * inputHeight * inputWidth;
#ifndef __HIP_PLATFORM_HCC__
#pragma unroll
#endif
    for (int kH = 0; kH < KH_LIMIT; ++kH) {
      // Input row sampled by this filter row. It may fall inside the padding.
      const int h_in = -padHeight + h * strideHeight + kH * dilationHeight;
#ifndef __HIP_PLATFORM_HCC__
#pragma unroll
#endif
      for (int kW = 0; kW < KW_LIMIT; ++kW) {
        const int w_in = -padWidth + w * strideWidth + kW * dilationWidth;
        if ((h_in >= 0) && (h_in < inputHeight) && (w_in >= 0) && (w_in < inputWidth)) {
          const IndexType offset = offset0 + h_in * inputWidth + w_in;
          value = THCNumerics<AccT>::add(
            value,
            THCNumerics<AccT>::mul(
              ScalarConvert<T, AccT>::to(weight.data()[weightOffset]),
              ScalarConvert<T, AccT>::to(input.data()[offset])));
        }
        // Advance even when the tap is skipped, to stay aligned with (kH, kW).
        ++weightOffset;
      }
    }
    output.data()[linearIndex] = ScalarConvert<AccT, T>::to(value);
  }
}
// Backward-to-input kernel of the depthwise convolution.
//
// Each loop iteration produces one gradInput element (n, c, h, w). It sums,
// over the depthwiseMultiplier output channels och = c * depthwiseMultiplier + m
// and every filter tap (kh, kw), the product weight[och][kh][kw] * gradOutput[n][och][h_out][w_out].
// Only taps whose shifted input position is divisible by the stride and maps
// inside the output contribute.
//
// Launch: 1-D grid and blocks of any size. The loop advances by
// gridDim.x * blockDim.x over [0, totalElements).
// kSize / stride != 0 make the window size / stride compile-time constants;
// 0 falls back to the runtime arguments.
// All tensors are addressed through .data() with flat NCHW offsets, i.e. they
// are assumed to be contiguous.
// Index/offset arithmetic is carried out in IndexType. This matches the
// forward kernel and avoids int overflow for a 64-bit IndexType.
template <typename T, typename AccT, typename IndexType, int kSize, int stride>
__global__ void spatialDepthwiseConvolutionUpdateGradInput(
    const THCDeviceTensor<T, 4> gradOutput,
    THCDeviceTensor<T, 4> gradInput,
    const THCDeviceTensor<T, 4> weight,
    IndexType totalElements,
    const int inputChannels,
    const int depthwiseMultiplier,
    const int outputChannels,
    const int inputWidth, const int inputHeight,
    const int outputWidth, const int outputHeight,
    const int kernelWidth, const int kernelHeight,
    const int strideWidth, const int strideHeight,
    const int padWidth, const int padHeight,
    const int dilationWidth, const int dilationHeight)
{
  const int KW_LIMIT = (kSize != 0) ? kSize : kernelWidth;
  const int KH_LIMIT = (kSize != 0) ? kSize : kernelHeight;
  const int strideW = (stride != 0) ? stride : strideWidth;
  const int strideH = (stride != 0) ? stride : strideHeight;

  const IndexType gridStride = static_cast<IndexType>(gridDim.x) * blockDim.x;
  for (IndexType linearIndex = static_cast<IndexType>(blockIdx.x) * blockDim.x + threadIdx.x;
       linearIndex < totalElements;
       linearIndex += gridStride) {
    // Decompose linearIndex into (n, c, h, w) of gradInput without modulo.
    // The quotients stay in IndexType so they are never truncated to int.
    const IndexType nch = linearIndex / inputWidth;
    const int w = static_cast<int>(linearIndex - nch * inputWidth);
    const IndexType nc = nch / inputHeight;
    const int h = static_cast<int>(nch - nc * inputHeight);
    const IndexType n = nc / inputChannels;
    const int c = static_cast<int>(nc - n * inputChannels);

    AccT value = ScalarConvert<int, AccT>::to(0);
#ifndef __HIP_PLATFORM_HCC__
#pragma unroll
#endif
    for (int multiplier = 0; multiplier < depthwiseMultiplier; ++multiplier) {
      const int och = (c * depthwiseMultiplier) + multiplier;
      int weightOffset = och * kernelHeight * kernelWidth;
      // Index of row 0 of the (n, och) gradOutput plane, in units of rows.
      const IndexType planeRow = (n * outputChannels + och) * outputHeight;
#ifndef __HIP_PLATFORM_HCC__
#pragma unroll
#endif
      for (int kh = 0; kh < KH_LIMIT; ++kh) {
        // Output row (before dividing by the stride) that this filter row
        // would read input row h from.
        const int hShifted = h + padHeight - kh * dilationHeight;
        // NOTE(review): unlike every other loop in this file, the kw loop is
        // unrolled only on HIP (#ifdef rather than #ifndef). Kept as-is;
        // confirm this is intentional.
#ifdef __HIP_PLATFORM_HCC__
#pragma unroll
#endif
        for (int kw = 0; kw < KW_LIMIT; ++kw) {
          const int wShifted = w + padWidth - kw * dilationWidth;
          // Only taps aligned to the stride grid correspond to an output position.
          if ((hShifted % strideH == 0) && (wShifted % strideW == 0)) {
            const int h_out = hShifted / strideH;
            const int w_out = wShifted / strideW;
            if ((h_out >= 0) && (h_out < outputHeight)
                && (w_out >= 0) && (w_out < outputWidth)) {
              const IndexType offset = (planeRow + h_out) * outputWidth + w_out;
              value = THCNumerics<AccT>::add(
                value,
                THCNumerics<AccT>::mul(
                  ScalarConvert<T, AccT>::to(weight.data()[weightOffset]),
                  ScalarConvert<T, AccT>::to(gradOutput.data()[offset])));
            }
          }
          // Advance even when the tap is skipped, to stay aligned with (kh, kw).
          ++weightOffset;
        }
      }
    }
    gradInput.data()[linearIndex] = ScalarConvert<AccT, T>::to(value);
  }
}
// Weight-gradient kernel of the depthwise convolution.
//
// Each block computes exactly one gradWeight element (ch, kH, kW), chosen
// from blockIdx.x. It sums input * gradOutput over the whole batch and output
// image, then stores the result with plain assignment (it does not add to the
// existing gradWeight value).
//
// Launch contract:
//  * gridDim.x is expected to be kernelChannels * kernelHeight * kernelWidth,
//    one block per filter tap.
//  * blockDim.x should be a multiple of C10_WARP_SIZE. Each warp handles
//    batch items warpId, warpId + nwarps, ..., and threads beyond the last
//    full warp would double count.
//  * Shared memory is consumed by reduceBlock through SharedMem<AccT>.
//    NOTE(review): presumably blockDim.x * sizeof(AccT) bytes of dynamic
//    shared memory — confirm at the launch site.
// All tensors are addressed through .data() with flat NCHW offsets, i.e. they
// are assumed to be contiguous. Offsets are computed in IndexType so a 64-bit
// IndexType avoids int overflow.
template <typename T, typename AccT, typename IndexType>
__global__ void spatialDepthwiseConvolutionAccGradParameters(
    const THCDeviceTensor<T, 4> gradOutput,
    const THCDeviceTensor<T, 4> input,
    THCDeviceTensor<T, 4> gradWeight,
    const int batchSize,
    const int inputChannels,
    const int kernelChannels,
    const int depthwiseMultiplier,
    const int inputWidth, const int inputHeight,
    const int outputWidth, const int outputHeight,
    const int kernelWidth, const int kernelHeight,
    const int strideWidth, const int strideHeight,
    const int padWidth, const int padHeight,
    const int dilationWidth, const int dilationHeight)
{
  const int channelStride = kernelWidth * kernelHeight;
  // Have to use a statically typed Shared Memory pointer
  SharedMem<AccT> smem;

  // Each block accumulates one (channel, kH, kW) filter tap; decode it from blockIdx.
  const int bidx = blockIdx.x;
  const int kW = bidx % kernelWidth;
  const int kH = (bidx / kernelWidth) % kernelHeight;
  const int ch = bidx / channelStride;
  // Input channel that feeds filter channel ch.
  const int inputCh = ch / depthwiseMultiplier;

  AccT grad = ScalarConvert<float, AccT>::to(0.0f);

  const int laneId = threadIdx.x % C10_WARP_SIZE;
  const int batch = threadIdx.x / C10_WARP_SIZE;
  const int nwarps = blockDim.x / C10_WARP_SIZE;
  const int imageElements = outputWidth * outputHeight;
  const IndexType inputPlaneSize = static_cast<IndexType>(inputHeight) * inputWidth;

  // One warp per batch item (a threadblock was used to sum over NHW in the
  // original kernel). A warp sums over the HW dimension. If batchSize exceeds
  // the number of warps, warp w loops over items w, w + nwarps, w + 2*nwarps, ...
  // The block reduction below then combines all warps, so the full reduction
  // still covers NHW. Looping over the batch index directly avoids a modulo
  // inside the hot loop.
  for (int batchIdx = batch; batchIdx < batchSize; batchIdx += nwarps) {
    // Start of the (batchIdx, inputCh) input plane and of the (batchIdx, ch)
    // gradOutput plane. Both are loop-invariant for the inner loop.
    const IndexType inputBase =
        (static_cast<IndexType>(batchIdx) * inputChannels + inputCh) * inputPlaneSize;
    const IndexType outputBase =
        (static_cast<IndexType>(batchIdx) * kernelChannels + ch) * imageElements;

    // Each lane visits every C10_WARP_SIZE-th element of this gradOutput plane.
    for (IndexType idx = laneId; idx < imageElements; idx += C10_WARP_SIZE) {
      // Position in gradOutput, then the input position this filter tap reads.
      const int go_w_offset = static_cast<int>(idx % outputWidth);
      const int go_h_offset = static_cast<int>(idx / outputWidth);
      const int i_w_offset = (go_w_offset * strideWidth) + (kW * dilationWidth) - padWidth;
      const int i_h_offset = (go_h_offset * strideHeight) + (kH * dilationHeight) - padHeight;

      if (i_w_offset >= 0 && i_h_offset >= 0 && i_w_offset < inputWidth && i_h_offset < inputHeight) {
        const IndexType inputOffset = inputBase + i_h_offset * inputWidth + i_w_offset;
        const IndexType outputOffset = outputBase + idx;
        grad = THCNumerics<AccT>::add(
          grad,
          THCNumerics<AccT>::mul(
            ScalarConvert<T, AccT>::to(input.data()[inputOffset]),
            ScalarConvert<T, AccT>::to(gradOutput.data()[outputOffset])));
      }
    }
  }
  __syncthreads();

  // Every thread now holds a partial sum. Combine them across the block;
  // all threads must reach this call.
  AccT *buf = smem.getPointer();
  AccT tval = reduceBlock<AccT, ReduceAdd<AccT>>(
    buf, blockDim.x, grad, ReduceAdd<AccT>(), ScalarConvert<float, AccT>::to(0));

  // After reduction, thread 0 holds the block total and writes it out.
  if (threadIdx.x == 0) {
    const int weightOffset = kW + (kernelWidth * kH) + (kernelWidth * kernelHeight * ch);
    gradWeight.data()[weightOffset] = ScalarConvert<AccT, T>::to(tval);
  }
}
#include <THCUNN/generic/SpatialDepthwiseConvolution.cu>
#include <THC/THCGenerateFloatTypes.h>