forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
UpSampleNearest1d.cu
231 lines (187 loc) · 7.18 KB
/
UpSampleNearest1d.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/native/cuda/UpSample.cuh>
namespace at {
namespace native {
namespace {
// Upper bound on threads per block used when the launch configurations in
// this file are built (combined with the device's maxThreadsPerBlock).
#define MAX_THREADS 512

// Forward kernel for 1D nearest-neighbour upsampling.
//
// Expected launch: 1D grid of 1D blocks with at least dim_c * dst_dim_w
// threads in total -- one thread per (channel, output column) pair. Threads
// beyond that range exit immediately. Each active thread then walks the batch
// dimension, copying the nearest source element for every sample.
//
// The indexing treats `input` as a dense (dim_b, dim_c, src_dim_w) buffer and
// `output` as a dense (dim_b, dim_c, dst_dim_w) buffer. All element offsets are
// 32-bit ints, so the caller must keep both tensors within int32 range.
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest1d_out_frame(
    const scalar_t* input,
    size_t dim_b,
    size_t dim_c,
    size_t src_dim_w,
    size_t dst_dim_w,
    scalar_t* output,
    float scale_factor) {
  const int plane_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Tail threads of the last block have no (channel, column) pair to handle.
  if (plane_idx >= dim_c * dst_dim_w) {
    return;
  }

  // plane_idx < dim_c * dst_dim_w, so the quotient is already a valid channel.
  const int channel = plane_idx / dst_dim_w;
  const int out_x = plane_idx % dst_dim_w;
  const int in_x =
      nearest_neighbor_compute_source_index(scale_factor, out_x, src_dim_w);

  // Distance between consecutive batch samples in each buffer.
  const int in_plane = dim_c * src_dim_w;
  const int out_plane = dim_c * dst_dim_w;
  const int in_base = channel * src_dim_w + in_x;

  for (int batch = 0; batch < dim_b; ++batch) {
    output[plane_idx + batch * out_plane] = input[in_base + batch * in_plane];
  }
}
// Backward kernel for 1D nearest-neighbour upsampling.
//
// Naming here is from the kernel's point of view: the "src" buffer is grad_o
// (width src_dim_w) and the "dst" buffer is grad_i (width dst_dim_w). One
// thread handles one (channel, grad_i column) pair; threads past
// dim_c * dst_dim_w exit. For every batch sample the thread sums a contiguous
// run of grad_o columns [lo, hi) in accscalar_t and stores the result.
// NOTE(review): [lo, hi) is presumably the set of grad_o columns whose forward
// nearest source is this grad_i column; the exact mapping (including the
// src_dim_w + 1 clamp used for `hi`) depends on
// nearest_neighbor_compute_source_index, defined elsewhere.
// Element offsets are 32-bit ints; the caller must keep sizes within int32.
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest1d_backward_out_frame(
    const scalar_t* grad_o,
    size_t dim_b,
    size_t dim_c,
    size_t src_dim_w,
    size_t dst_dim_w,
    scalar_t* grad_i,
    float scale_factor) {
  const int plane_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t plane_size = dim_c * dst_dim_w;
  if (plane_idx >= plane_size) {
    return;
  }

  // plane_idx < dim_c * dst_dim_w, so the quotient is already a valid channel.
  const int channel = plane_idx / dst_dim_w;
  const int in_x = plane_idx % dst_dim_w;
  const int lo =
      nearest_neighbor_compute_source_index(scale_factor, in_x, src_dim_w);
  const int hi = nearest_neighbor_compute_source_index(
      scale_factor, in_x + 1, src_dim_w + 1);
  const int run_length = hi - lo;

  for (int batch = 0; batch < dim_b; ++batch) {
    // Offset of grad_o[batch][channel][lo].
    const int first = batch * dim_c * src_dim_w + channel * src_dim_w + lo;
    accscalar_t acc = 0;
    for (int k = 0; k < run_length; ++k) {
      acc += grad_o[first + k];
    }
    const int out_idx = plane_idx + batch * plane_size;
    grad_i[out_idx] = acc;
  }
}
// Host-side driver for the forward pass.
//
// Validates the arguments, resizes `output` to
// (input.size(0), input.size(1), output_size[0]) and launches
// upsample_nearest1d_out_frame on the current CUDA stream.
//
// Parameters:
//   output      - destination tensor; resized in place.
//   input_      - source tensor, read as (N, C, W); made contiguous before use.
//   output_size - must hold exactly one value: the output width.
//   scales_1    - forwarded to compute_scales_value to derive the kernel's
//                 scale factor.
// Errors are raised through TORCH_CHECK / AT_ASSERT / AT_CUDA_CHECK.
static void upsample_nearest1d_out_cuda_template(
    Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size,
    double scales_1) {
  TensorArg input_arg{input_, "input_", 1}, output_arg{output, "output", 2};
  checkAllSameGPU("upsample_nearest1d_out_cuda", {input_arg, output_arg});

  TORCH_CHECK(
      output_size.size() == 1,
      "It is expected output_size equals to 1, but got size ",
      output_size.size());

  // Keep sizes 64-bit: narrowing to int here could wrap an oversized request
  // into a small positive width that passes every later check.
  int64_t output_width = output_size[0];
  int64_t nbatch = input_.size(0);
  int64_t channels = input_.size(1);
  int64_t input_width = input_.size(2);

  upsample_1d_shape_check(
      input_, Tensor(), nbatch, channels, input_width, output_width);

  AT_ASSERT(input_width > 0 && output_width > 0);

  Tensor input = input_.contiguous();
  output.resize_({nbatch, channels, output_width});

  // Nothing to compute for an empty result; returning here also keeps the
  // division by nbatch and the grid size below well defined.
  if (output.numel() == 0) {
    return;
  }

  // The kernel indexes BOTH tensors with 32-bit ints. The input must be
  // checked too: when downsampling it is the larger of the two.
  TORCH_CHECK(
      output.numel() <= std::numeric_limits<int32_t>::max(),
      "upsample_nearest1d_out_cuda: output tensor is too large for int32 indexing");
  TORCH_CHECK(
      input.numel() <= std::numeric_limits<int32_t>::max(),
      "upsample_nearest1d_out_cuda: input tensor is too large for int32 indexing");

  // One thread per (channel, output column); each thread loops over the batch.
  unsigned int n = output.numel() / nbatch;
  dim3 bdim{std::min<unsigned int>(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
  dim3 gdim{cuda::ATenCeilDiv(n, bdim.x)};

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "upsample_nearest1d_out_frame", [&] {
        auto idata = input.data_ptr<scalar_t>();
        auto odata = output.data_ptr<scalar_t>();

        const float scale_factor =
            compute_scales_value<float>(scales_1, input_width, output_width);

        upsample_nearest1d_out_frame<scalar_t><<<gdim, bdim, 0, stream>>>(
            idata, nbatch, channels, input_width, output_width, odata, scale_factor);
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
// Host-side driver for the backward pass.
//
// Validates the arguments, resizes `grad_input` to input_size and launches
// upsample_nearest1d_backward_out_frame on the current CUDA stream.
//
// Parameters:
//   grad_input   - destination gradient; resized in place to input_size.
//   grad_output_ - incoming gradient, read as (N, C, output_size[0]);
//                  made contiguous before use.
//   output_size  - must hold exactly one value: the forward output width.
//   input_size   - must hold exactly three values: forward input (N, C, W).
//   scales_1     - forwarded to compute_scales_value_backwards.
// Errors are raised through TORCH_CHECK / AT_CUDA_CHECK.
static void upsample_nearest1d_backward_out_cuda_template(
    Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    double scales_1) {
  TensorArg grad_input_arg{grad_input, "grad_input", 1},
      grad_output_arg{grad_output_, "grad_output_", 2};
  checkAllSameGPU(
      "upsample_nearest1d_backward_out_cuda_template",
      {grad_output_arg, grad_input_arg});

  TORCH_CHECK(
      output_size.size() == 1,
      "It is expected output_size equals to 1, but got size ",
      output_size.size());

  TORCH_CHECK(
      input_size.size() == 3,
      "It is expected input_size equals to 3, but got size ",
      input_size.size());

  // Keep sizes 64-bit so oversized values are not silently wrapped.
  int64_t output_width = output_size[0];
  int64_t nbatch = input_size[0];
  int64_t channels = input_size[1];
  int64_t input_width = input_size[2];

  upsample_1d_shape_check(
      Tensor(), grad_output_, nbatch, channels, input_width, output_width);

  Tensor grad_output = grad_output_.contiguous();
  grad_input.resize_({nbatch, channels, input_width});

  // Guard explicitly instead of relying on the shape check: this path passes
  // no input tensor to it. An empty grad_input has nothing to fill, and
  // continuing would divide by nbatch == 0 or launch a zero-sized grid.
  if (grad_input.numel() == 0) {
    return;
  }

  // The kernel indexes BOTH tensors with 32-bit ints. grad_output must be
  // checked too: when upsampling it is the larger of the two.
  TORCH_CHECK(
      grad_input.numel() <= std::numeric_limits<int32_t>::max(),
      "upsample_nearest1d_backward_out_cuda: grad_input is too large for int32 indexing");
  TORCH_CHECK(
      grad_output.numel() <= std::numeric_limits<int32_t>::max(),
      "upsample_nearest1d_backward_out_cuda: grad_output is too large for int32 indexing");

  // One thread per (channel, grad_input column); each thread loops over batch.
  unsigned int n = grad_input.numel() / nbatch;
  dim3 bdim{std::min<unsigned int>(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
  dim3 gdim{cuda::ATenCeilDiv(n, bdim.x)};

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "upsample_nearest1d_backward_out_frame", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = grad_input.data_ptr<scalar_t>();
        auto odata = grad_output.data_ptr<scalar_t>();

        const float scale_factor = compute_scales_value_backwards<float>(
            scales_1, output_width, input_width);

        upsample_nearest1d_backward_out_frame<scalar_t, accscalar_t>
            <<<gdim, bdim, 0, stream>>>(
                odata, nbatch, channels, output_width, input_width, idata, scale_factor);
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
} // namespace
// out= entry point for the forward op: fills `output` (resized as needed)
// with the nearest-neighbour upsampling of `input` and returns it.
Tensor& upsample_nearest1d_out_cuda(
    Tensor& output,
    const Tensor& input,
    IntArrayRef output_size,
    double scales_1) {
  Tensor& result = output;
  upsample_nearest1d_out_cuda_template(
      result, input, output_size, scales_1);
  return result;
}
// Functional entry point for the forward op: returns a new tensor holding the
// nearest-neighbour upsampling of `input` to width output_size[0].
Tensor upsample_nearest1d_cuda(const Tensor& input, IntArrayRef output_size, double scales_1) {
  // Start from an empty tensor with input's dtype/device/layout. The template
  // resize_()s it to the real output shape, so no input-sized buffer is
  // allocated only to be thrown away (empty_like(input) did exactly that).
  Tensor output = at::empty({0}, input.options());
  upsample_nearest1d_out_cuda_template(output, input, output_size, scales_1);
  return output;
}
// out= entry point for the backward op: writes the gradient w.r.t. the
// forward input into `grad_input` (resized to input_size) and returns it.
Tensor& upsample_nearest1d_backward_out_cuda(
    Tensor& grad_input,
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    double scales_1) {
  Tensor& result = grad_input;
  upsample_nearest1d_backward_out_cuda_template(
      result, grad_output, output_size, input_size, scales_1);
  return result;
}
// Functional entry point for the backward op: returns a new tensor of shape
// input_size holding the gradient w.r.t. the forward input.
Tensor upsample_nearest1d_backward_cuda(
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    double scales_1) {
  // Start from an empty tensor with grad_output's dtype/device/layout and let
  // the template resize_() it to input_size. empty_like(grad_output) allocated
  // a grad_output-sized buffer; when upsampling, resize_() then shrank the
  // tensor but kept that oversized storage alive.
  Tensor grad_input = at::empty({0}, grad_output.options());
  upsample_nearest1d_backward_out_cuda_template(
      grad_input, grad_output, output_size, input_size, scales_1);
  return grad_input;
}
} // namespace native
} // namespace at