forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
THCTensorTopK.cuh
144 lines (116 loc) · 4.98 KB
/
THCTensorTopK.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#ifndef THC_TENSOR_TOPK_CUH
#define THC_TENSOR_TOPK_CUH
#include <c10/macros/Macros.h>
#include <aten/src/ATen/native/cuda/SortingRadixSelect.cuh>
using namespace at::native;
template <typename T, typename IndexType, int Dim, bool Order>
C10_LAUNCH_BOUNDS_1(1024)
// gatherTopK: for each slice of `input`, writes the `outputSliceSize` (k)
// largest (Order == true) or smallest (Order == false) elements into the
// matching slice of `topK`, and their within-slice positions into `indices`.
//
// Launch layout, as required by the code below:
//   * one block per slice: the slice id comes from getLinearBlockId, and
//     blocks whose id is >= numInputSlices exit immediately;
//   * at most 1024 threads per block (C10_LAUNCH_BOUNDS_1(1024));
//   * static shared memory only: 32 ints (64 under HIP), used as scratch by
//     radixSelect and exclusiveBinaryPrefixScan.
// Output within a slice is NOT sorted. Elements strictly better than the
// k-th value are written first. Elements tied with the k-th value then fill
// the remaining slots. Each group is presumably in increasing input
// position, assuming the prefix scan ranks by threadIdx.x (confirm).
//
// `numTopKSlices` is accepted for interface compatibility but not read here.
__global__ void gatherTopK(TensorInfo<T, IndexType> input,
                           IndexType inputSliceSize,
                           IndexType outputSliceSize, // aka `k`
                           IndexType numInputSlices,
                           IndexType inputWithinSliceStride,
                           TensorInfo<T, IndexType> topK,
                           IndexType numTopKSlices,
                           IndexType topKWithinSliceStride,
                           TensorInfo<int64_t, IndexType> indices,
                           IndexType indicesWithinSliceStride) {
  // Indices are limited to integer fp precision, so counts can fit in
  // int32, regardless of IndexType
#ifdef __HIP_PLATFORM_HCC__
  __shared__ int smem[64];
#else
  __shared__ int smem[32]; // one per each warp, up to warp limit
#endif

  // The slice id is uniform across the block, so this early exit is taken by
  // every thread of a surplus block together. No barrier is left half-reached.
  IndexType slice = getLinearBlockId<IndexType>();
  if (slice >= numInputSlices) {
    return;
  }

  // Find the start offset for our slice in each of the three tensors.
  IndexType sliceStartIndex =
    IndexToOffset<T, IndexType, Dim>::get(slice, input);
  IndexType topKSliceStartIndex =
    IndexToOffset<T, IndexType, Dim>::get(slice, topK);
  IndexType indicesSliceStartIndex =
    IndexToOffset<int64_t, IndexType, Dim>::get(slice, indices);

  T* inputSliceStart = &input.data[sliceStartIndex];
  T* topKSliceStart = &topK.data[topKSliceStartIndex];
  int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex];

  // Find the k-th highest (Order) / lowest (!Order) element in our input.
  T topKValue = ScalarConvert<int, T>::to(0);
  radixSelect<T, typename TopKTypeConfig<T>::RadixType, IndexType, Order>(
    inputSliceStart, outputSliceSize,
    inputSliceSize, inputWithinSliceStride,
    smem, &topKValue);

  // radixSelect ranks elements by their TopKTypeConfig<T>::RadixType key, so
  // every selection test below compares those same keys instead of using
  // IEEE comparisons (THCNumerics). Any comparison against NaN is false. If
  // the k-th value was NaN, or a NaN ranked above it, float compares would
  // never select those elements, and some of the k output slots would be
  // left unwritten.
  const auto topKKey = TopKTypeConfig<T>::convert(topKValue);

  // Every value whose key is strictly less/greater than `topKKey`
  // (depending on sort dir) is in the top-K.
  // The top-K value itself might not be unique.
  //
  // Since there are a variable number of elements that we see that
  // are within the top-k, we don't know at what index to write out
  // the resulting values.
  // In order to get this, we perform an exclusive prefix sum of
  // `hasTopK`. This will return the resulting index into which we
  // need to write the result, if a thread has a result.
  // All threads need to participate in the loop and the prefix sum,
  // but not necessarily in the load; hence loop bounds being rounded
  // up to a multiple of the block dim.
  IndexType numIterations = THCRoundUp(inputSliceSize, (IndexType) blockDim.x);
  IndexType writeIndexStart = 0;

  // Pass 1: elements strictly better than the k-th value.
  for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
    bool inRange = (i < inputSliceSize);
    T v =
      inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : ScalarConvert<int, T>::to(0);
    const auto vKey = TopKTypeConfig<T>::convert(v);
    bool hasTopK = inRange && (Order ? (vKey > topKKey) : (vKey < topKKey));

    int index;
    int carry;
    exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());

    if (hasTopK) {
      // Kept in IndexType so the bound check and offsets never narrow to int.
      IndexType writeIndex = writeIndexStart + index;
      assert(writeIndex < outputSliceSize);

      IndexType topKOffset = writeIndex * topKWithinSliceStride;
      IndexType indexOffset = writeIndex * indicesWithinSliceStride;

      topKSliceStart[topKOffset] = v;
      indicesSliceStart[indexOffset] = i;
    }

    writeIndexStart += carry;
  }

  // We need to fill in the rest with values whose key equals the top-K key.
  // The number that we need is outputSliceSize - writeIndexStart.
  // There might be more than that number available, in which case we take
  // the ones seen first. We do this via a prefix sum to calculate indices
  // for writing results.
  assert(outputSliceSize >= writeIndexStart);
  IndexType topKRemaining = (outputSliceSize - writeIndexStart);

  // Pass 2: ties with the k-th value, until k outputs have been written.
  for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
    bool inRange = (i < inputSliceSize);
    T v =
      inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : ScalarConvert<int, T>::to(0);
    const auto vKey = TopKTypeConfig<T>::convert(v);
    bool hasTopK = inRange && (vKey == topKKey);

    int index;
    int carry;
    exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());

    if (hasTopK && index < topKRemaining) {
      IndexType writeIndex = writeIndexStart + index;
      assert(writeIndex < outputSliceSize);

      IndexType topKOffset = writeIndex * topKWithinSliceStride;
      IndexType indexOffset = writeIndex * indicesWithinSliceStride;

      topKSliceStart[topKOffset] = v;
      indicesSliceStart[indexOffset] = i;
    }

    // `carry` comes from the block-wide scan, so every thread takes this
    // branch together and no thread leaves the loop early on its own.
    if (carry >= topKRemaining) {
      break;
    }

    topKRemaining -= carry;
    writeIndexStart += carry;
  }
}
#undef RADIX_BITS
#undef RADIX_SIZE
#undef RADIX_MASK
#endif // THC_TENSOR_TOPK_CUH