// FunctionOfAMatrixUtilsKernel.cu
// Origin: fork of pytorch/pytorch (114 lines, 95 loc, 3.29 KB).
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/FunctionOfAMatrixUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
namespace at::native {
namespace {
// Element-wise driver kernel.
//
// Each block owns a tile of n_threads * n_elems_per_thread consecutive
// indices. A thread starts at its own slot inside the tile and then advances
// by n_threads, making at most n_elems_per_thread calls to f; it stops as
// soon as its index reaches total_n_elems.
//
// Layout: only blockIdx.x and threadIdx.x are read, and the index arithmetic
// assumes the block was launched with n_threads threads along x.
// No shared memory is used.
template <int n_threads, int n_elems_per_thread, typename func_t>
C10_LAUNCH_BOUNDS_2(n_threads, n_elems_per_thread)
__global__ void _elemwise_kernel(int total_n_elems, func_t f) {
  constexpr int tile_size = n_threads * n_elems_per_thread;
  // First index handled by this thread inside its block's tile.
  const int first_idx = tile_size * blockIdx.x + threadIdx.x;
  #pragma unroll
  for (int step = 0; step < n_elems_per_thread; ++step) {
    const int idx = first_idx + step * n_threads;
    // Indices only grow, so once we are past the end nothing is left to do.
    if (idx >= total_n_elems) {
      break;
    }
    f(idx);
  }
}
// Host-side launcher for _elemwise_kernel.
//
// Launches enough blocks of n_threads threads so that every index in
// [0, total_n_elems) is passed to f once, with each block covering
// n_threads * n_elems_per_thread indices. The kernel is enqueued on the
// current CUDA stream and the launch status is checked immediately; no
// synchronization is performed here.
//
// total_n_elems must be non-negative. An empty range is a no-op.
template <int n_threads, int n_elems_per_thread, typename func_t>
void _lauch_kernel(int total_n_elems, const func_t& f) {
  TORCH_INTERNAL_ASSERT(
    total_n_elems >= 0 && total_n_elems <= std::numeric_limits<int32_t>::max()
  );
  // A grid with zero blocks is an invalid launch configuration, so bail out
  // before computing the grid instead of relying on every caller to guard it.
  if (total_n_elems == 0) {
    return;
  }
  constexpr int64_t total_work_block =
      static_cast<int64_t>(n_threads) * n_elems_per_thread;
  // Ceil-div in 64-bit: total_n_elems + total_work_block - 1 does not fit in
  // a signed 32-bit int when total_n_elems is close to INT_MAX.
  const int64_t n_blocks =
      (static_cast<int64_t>(total_n_elems) + total_work_block - 1) /
      total_work_block;
  const dim3 block(n_threads);
  const dim3 grid(static_cast<unsigned int>(n_blocks));
  auto stream = at::cuda::getCurrentCUDAStream();
  _elemwise_kernel<n_threads, n_elems_per_thread, func_t>
    <<<grid, block, 0, stream>>>(total_n_elems, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Typed implementation of the linear-combination kernel.
//
// The iterator has three operands: 0 = output, 1 = input, 2 = coefficients.
// For every iterator index, with out/in/coeff pointing at that index's
// element of each operand, this accumulates
//
//   *out += sum_{i = 0 .. num_summations-1} in[i * in_stride] * coeff[i * coeff_stride]
//
// The output element is added to, not overwritten.
//
// Parameters:
//   iter           - iterator over (output, input, coefficients).
//   in_stride      - step between consecutive summands of the input,
//                    in elements of scalar_t.
//   coeff_stride   - step between consecutive coefficients, in elements of
//                    scalar_value_type<scalar_t>::type.
//   num_summations - number of terms in the sum.
//
// The strides and the term count are kept as 64-bit values: the caller
// supplies int64_t, and i * stride is formed outside the iterator's own
// offset computation, so the 32-bit-indexing split below does not bound it.
template <typename scalar_t>
void _compute_linear_combination_internal_kernel(
  TensorIterator& iter,
  int64_t in_stride,
  int64_t coeff_stride,
  int64_t num_summations
) {
  if (iter.numel() == 0) {
    return;
  }

  // The offset calculator below works with 32-bit indices; split the
  // iterator into pieces that satisfy that and handle each one recursively.
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      _compute_linear_combination_internal_kernel<scalar_t>(
        sub_iter, in_stride, coeff_stride, num_summations
      );
    }
    return;
  }

  auto offset_calc = make_offset_calculator<3>(iter);
  char* __restrict__ out_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  // Input and coefficients are only read inside the kernel.
  const char* __restrict__ in_ptr =
    reinterpret_cast<const char*>(iter.data_ptr(1));
  const char* __restrict__ coeff_ptr =
    reinterpret_cast<const char*>(iter.data_ptr(2));

  auto loop = [=]C10_DEVICE(int idx) {
    // Per-operand byte offsets for this iterator index.
    auto offsets = offset_calc.get(idx);

    auto* __restrict__ out_data = reinterpret_cast<scalar_t*>(
      out_ptr + offsets[0]
    );
    const auto* __restrict__ in_data = reinterpret_cast<const scalar_t*>(
      in_ptr + offsets[1]
    );
    // Coefficients are read as the underlying value type of scalar_t.
    using primitive_t = typename scalar_value_type<scalar_t>::type;
    const auto* __restrict__ coeff_data = reinterpret_cast<const primitive_t*>(
      coeff_ptr + offsets[2]
    );

    // perform summation
    for (int64_t i = 0; i < num_summations; ++i) {
      *out_data += in_data[i * in_stride] * coeff_data[i * coeff_stride];
    }
  };

  _lauch_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
}
// CUDA entry point for the linear-combination computation.
//
// Selects scalar_t from the iterator's dtype (all standard types, complex
// types, plus Half, Bool and BFloat16) and forwards every argument unchanged
// to _compute_linear_combination_internal_kernel<scalar_t>.
void _compute_linear_combination_cuda_kernel(
    TensorIterator& iter,
    int64_t in_stride,
    int64_t coeff_stride,
    int64_t num_summations) {
  const auto dtype = iter.dtype();
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      at::ScalarType::Half,
      at::ScalarType::Bool,
      at::ScalarType::BFloat16,
      dtype,
      "_compute_linear_combination_cuda",
      [&]() {
        _compute_linear_combination_internal_kernel<scalar_t>(
            iter, in_stride, coeff_stride, num_summations);
      });
}
}
// Hook _compute_linear_combination_cuda_kernel into
// _compute_linear_combination_stub via REGISTER_DISPATCH.
REGISTER_DISPATCH(_compute_linear_combination_stub, &_compute_linear_combination_cuda_kernel);
} // namespace at::native