forked from kulinseth/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
BatchNorm_miopen.cpp
234 lines (198 loc) · 8.22 KB
/
BatchNorm_miopen.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/miopen_batch_norm_native.h>
#include <ATen/ops/miopen_batch_norm_backward_native.h>
#endif
// TODO: Remove the condition on AT_ROCM_ENABLED entirely,
// don't build this file as part of CPU build.
#include <ATen/cuda/CUDAConfig.h>
#if !AT_ROCM_ENABLED()
namespace at { namespace native {
// See Note [ATen preprocessor philosophy]
// Fallback compiled only when AT_ROCM_ENABLED() is false (see the #if above):
// no MIOpen backend exists in this build, so every call reports an error via
// AT_ERROR. The parameter list matches the ROCm implementation further down in
// this file; none of the arguments are read.
std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
    const Tensor& input,
    const Tensor& weight,
    const c10::optional<Tensor>& bias_opt,
    const c10::optional<Tensor>& running_mean_opt,
    const c10::optional<Tensor>& running_var_opt,
    bool training,
    double exponential_average_factor,
    double epsilon) {
  AT_ERROR("miopen_batch_norm: ATen not compiled with MIOpen support");
}
// Fallback compiled only when AT_ROCM_ENABLED() is false (see the #if above):
// no MIOpen backend exists in this build, so every call reports an error via
// AT_ERROR. The parameter list matches the ROCm implementation further down in
// this file; none of the arguments are read.
std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const c10::optional<Tensor>& running_mean_opt,
    const c10::optional<Tensor>& running_var_opt,
    const c10::optional<Tensor>& save_mean_opt,
    const c10::optional<Tensor>& save_var_opt,
    double epsilon) {
  AT_ERROR("miopen_batch_norm_backward: ATen not compiled with MIOpen support");
}
}} // namespace at::native
#else // AT_ROCM_ENABLED
#include <ATen/miopen/Descriptors.h>
#include <ATen/miopen/Types.h>
#include <ATen/miopen/Utils.h>
#include <ATen/TensorUtils.h>
namespace at { namespace native {
namespace {
// Returns a view of `t` shaped (1, t.numel(), 1, ..., 1).
// The rank is `dim`, but never less than 2. The callers in this file use it
// to give per-feature tensors (weight etc.) the same rank as the input before
// building a TensorDescriptor from them.
Tensor expandScale(const Tensor& t, int64_t dim) {
  const int64_t rank = dim > 2 ? dim : int64_t{2};
  std::vector<int64_t> shape(static_cast<size_t>(rank), int64_t{1});
  shape[1] = t.numel();
  return t.view(shape);
}
} // namespace
// Batch-normalization forward pass implemented on top of MIOpen.
//
// Arguments:
//   input_t            - activations. Must be 2-D to 5-D (checkDimRange(2, 6))
//                        and contiguous; size(1) is the feature count.
//   weight_t           - per-feature scale, numel == input_t.size(1).
//   bias_t_opt         - per-feature shift. Optional in the signature, but it
//                        must be defined (checkAllDefined below).
//   running_mean_t_opt,
//   running_var_t_opt  - running statistics.
//                        * Inference: required.
//                        * Training: may be undefined (passed through
//                          at::maybe_data_ptr).
//   training           - selects miopenBatchNormalizationForwardTraining vs.
//                        miopenBatchNormalizationForwardInference.
//   exponential_average_factor - forwarded to the training call only.
//   epsilon            - forwarded to MIOpen.
//
// Returns (output, save_mean, save_var). output has input's sizes and options.
// save_mean/save_var are always defined:
//   * training:  num_features elements each, written by MIOpen;
//   * inference: empty (size 0).
std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
    const Tensor& input_t, const Tensor& weight_t, const c10::optional<Tensor>& bias_t_opt, const c10::optional<Tensor>& running_mean_t_opt, const c10::optional<Tensor>& running_var_t_opt,
    bool training, double exponential_average_factor, double epsilon)
{
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt);
  const Tensor& bias_t = *bias_t_maybe_owned;
  const Tensor& running_mean_t = c10::value_or_else(running_mean_t_opt, [] {return Tensor();});
  const Tensor& running_var_t = c10::value_or_else(running_var_t_opt, [] {return Tensor();});

  TensorArg input{ input_t, "input", 1 },
            weight{ weight_t, "weight", 2 },
            bias{ bias_t, "bias", 3 },
            running_mean{ running_mean_t, "running_mean", 4 },
            running_var{ running_var_t, "running_var", 5 };
  CheckedFrom c = "miopen_batch_norm";

  checkAllDefined(c, {input, weight, bias});
  if (!training) {
    // Inference normalizes with the running statistics, so they must exist.
    checkAllDefined(c, {running_mean, running_var});
  }
  checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
  if (input->scalar_type() != ScalarType::Half) {
    checkAllSameType(c, {input, weight});
  }
  // NOTE(review): for Half input, the weight's scalar type is not checked here.
  // miopen_batch_norm_backward requires Float weight in that case. Confirm
  // whether the forward should enforce the same.
  checkAllSameType(c, {weight, bias, running_mean, running_var});
  checkAllContiguous(c, {input, weight, bias, running_mean, running_var});
  checkDimRange(c, input, 2, 6 /* exclusive */);

  const int64_t num_features = input->size(1);
  for (const auto& t : {weight, bias, running_mean, running_var}) {
    if (t->defined()) {
      checkNumel(c, t, num_features);
    }
  }

  // 2-D input uses MIOpen's per-activation mode; higher ranks use spatial mode.
  const miopenBatchNormMode_t mode =
      (input->dim() == 2) ? miopenBNPerActivation : miopenBNSpatial;

  auto output_t = at::empty(input->sizes(), input->options());
  TensorArg output{ output_t, "output", 0 };

  auto handle = getMiopenHandle();
  auto dataType = getMiopenDataType(*input);
  // input and output have identical sizes, so one descriptor serves both.
  TensorDescriptor idesc{ *input, 4 };
  // Single descriptor for the per-feature tensors (weight, bias, running/saved stats).
  TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 };
  Constant one(dataType, 1);
  Constant zero(dataType, 0);

  Tensor save_mean, save_var;
  if (training) {
    // Reuses the num_features computed above; the original redeclared
    // (and shadowed) it here.
    save_mean = at::empty({ num_features }, weight_t.options());
    save_var = at::empty({ num_features }, weight_t.options());
    MIOPEN_CHECK(miopenBatchNormalizationForwardTraining(
      handle, mode, &one, &zero,
      idesc.desc(), input->data_ptr(),
      idesc.desc(), output->data_ptr(),
      wdesc.desc(),
      weight->data_ptr(),
      bias->data_ptr(),
      exponential_average_factor,
      at::maybe_data_ptr(running_mean),
      at::maybe_data_ptr(running_var),
      epsilon,
      save_mean.data_ptr(),
      save_var.data_ptr()));
  } else {
    // No statistics are saved in inference. Return empty tensors with the
    // weight's options so callers always receive defined tensors.
    save_mean = at::empty({0}, weight_t.options());
    save_var = at::empty({0}, weight_t.options());
    MIOPEN_CHECK(miopenBatchNormalizationForwardInference(
      handle, mode, &one, &zero,
      idesc.desc(), input->data_ptr(),
      idesc.desc(), output->data_ptr(),
      wdesc.desc(),
      weight->data_ptr(),
      bias->data_ptr(),
      running_mean->data_ptr(),
      running_var->data_ptr(),
      epsilon));
  }

  return std::tuple<Tensor, Tensor, Tensor>{output_t, save_mean, save_var};
}
std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
const Tensor& input_t,
const Tensor& grad_output_t,
const Tensor& weight_t,
// Unused: but we require them to be passed so that double backwards
// has access
const optional<Tensor>& running_mean_opt,
const optional<Tensor>& running_var_opt,
const optional<Tensor>& save_mean_t_opt,
const optional<Tensor>& save_var_t_opt,
double epsilon) {
// See [Note: hacky wrapper removal for optional tensor]
const Tensor& running_mean =
c10::value_or_else(running_mean_opt, [] { return Tensor(); });
const Tensor& running_var =
c10::value_or_else(running_var_opt, [] { return Tensor(); });
const Tensor& save_mean_t =
c10::value_or_else(save_mean_t_opt, [] { return Tensor(); });
const Tensor& save_var_t =
c10::value_or_else(save_var_t_opt, [] { return Tensor(); });
TensorArg input{ input_t, "input", 1 },
grad_output{ grad_output_t, "grad_output", 2 },
weight{ weight_t, "weight", 3 },
save_mean{ save_mean_t, "save_mean", 4 },
save_var{ save_var_t, "save_var", 5 };
CheckedFrom c = "miopen_batch_norm_backward";
checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
if (input->scalar_type() == ScalarType::Half) {
checkScalarType(c, weight, ScalarType::Float);
} else {
checkAllSameType(c, {input, weight});
}
checkAllSameType(c, {input, grad_output});
checkAllSameType(c, {weight, save_mean, save_var});
checkAllContiguous(c, {input, grad_output, save_mean, save_var});
checkDimRange(c, input, 2, 6 /* exclusive */);
checkSameSize(c, input, grad_output);
auto num_features = input->size(1);
for (auto t : {weight, save_mean, save_var}) {
checkNumel(c, t, num_features);
}
miopenBatchNormMode_t mode;
if (input->dim() == 2) {
mode = miopenBNPerActivation;
} else {
mode = miopenBNSpatial;
}
auto grad_input_t = at::empty(input->sizes(), input->options());
auto grad_weight_t = at::empty(weight->sizes(), weight->options());
auto grad_bias_t = at::empty(weight->sizes(), weight->options());
auto handle = getMiopenHandle();
auto dataType = getMiopenDataType(*input);
TensorDescriptor idesc{ *input, 4 }; // input, output, grad_output descriptor
TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 }; // descriptor for weight, bias, save_mean, etc.
Constant one(dataType, 1);
Constant zero(dataType, 0);
MIOPEN_CHECK(miopenBatchNormalizationBackward(
handle, mode, &one, &zero, &one, &zero,
idesc.desc(), input->data_ptr(),
idesc.desc(), grad_output->data_ptr(),
idesc.desc(), grad_input_t.data_ptr(),
wdesc.desc(), weight->data_ptr(),
grad_weight_t.data_ptr(),
grad_bias_t.data_ptr(),
epsilon,
save_mean->data_ptr(),
save_var->data_ptr()));
return std::tuple<Tensor,Tensor,Tensor>{grad_input_t, grad_weight_t, grad_bias_t};
}
}} // namespace at::native
#endif // AT_ROCM_ENABLED