forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FunctionalizeFallbackKernel.cpp
301 lines (276 loc) · 13.5 KB
/
FunctionalizeFallbackKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/InferSize.h>
#include <ATen/TensorUtils.h>
#include <torch/library.h>
#include <c10/util/irange.h>
#include <c10/util/strides.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/ATen.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_to_copy.h>
#include <ATen/ops/to_native.h>
#include <ATen/ops/lift.h>
#include <ATen/ops/lift_fresh.h>
#include <ATen/ops/lift_fresh_copy.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/as_strided.h>
#include <ATen/ops/as_strided_copy.h>
#include <ATen/ops/empty_strided_native.h>
#include <ATen/ops/_unsafe_view.h>
#include <utility>
#endif
namespace {
// Boxed fallback kernel for the Functionalize dispatch key. It is used for any
// operator that has no dedicated Functionalize kernel.
//
// Steps:
//   1. Reject schemas that carry alias annotations (ops that mutate or alias
//      their inputs). Those need a hand-written functionalization kernel.
//   2. For every Tensor / Tensor[] / Tensor?[] argument that is a functional
//      tensor: sync its pending updates, then replace it on the stack with its
//      unwrapped inner value.
//   3. Redispatch the op with Functionalize excluded.
//   4. If any input was functional, or the op had no tensor inputs at all
//      (e.g. a factory function), wrap the Tensor / Tensor[] / Tensor?[]
//      outputs back into functional tensors. Undefined tensor outputs are left as-is.
//
// `dispatchKeySet` is unused here; it is part of the boxed-kernel signature.
void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  TORCH_CHECK(
    !schema.hasAnyAliasInfo(),
    "Found a custom (non-ATen) operator that either mutates or aliases its inputs: ",
    op.operator_name().name, ".", op.operator_name().overload_name,
    ". Getting these operators to work with functionalization requires some extra work",
    ". For mutable ops you need to register a corresponding out-of-place variant of the op,",
    " and you also need to register a Functionalization kernel that performs some boilerplate,",
    " telling functionalization to map from the mutable op to the out-of-place op",
    ". See a more complete example of how to do this at ",
    "https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa.",
    " Please file a GitHub issue if you run into any problems.");

  // The op's arguments occupy the last `num_arguments` slots of the stack.
  const auto num_arguments = schema.arguments().size();
  const auto arguments_begin = stack->size() - num_arguments;
  auto arguments = torch::jit::last(stack, num_arguments);

  auto any_functional_inputs = false;
  auto any_tensor_inputs = false;
  for (const auto idx : c10::irange(num_arguments)) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      any_tensor_inputs = true;
      const auto& t = ivalue.toTensor();
      if (t.defined() && at::functionalization::impl::isFunctionalTensor(t)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(t);
        // Build the replacement before overwriting the slot: `t` refers into it.
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isTensorList()) {
      any_tensor_inputs = true;
      auto tensors = ivalue.toTensorList();
      if (at::functionalization::impl::isFunctionalTensor(tensors)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(tensors);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isOptionalTensorList()) {
      any_tensor_inputs = true;
      auto opt_tensors = ivalue.toOptionalTensorList();
      if (at::functionalization::impl::isFunctionalTensor(opt_tensors)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(opt_tensors);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
        (*stack)[arguments_begin + idx] = t_new;
      }
    }
  }

  // We should wrap the outputs if any inputs were wrapped,
  // OR if we're hitting a factory function (with no tensor inputs).
  const auto should_wrap_outputs = !any_tensor_inputs || any_functional_inputs;
  {
    at::AutoDispatchSkipFunctionalize guard;
    op.callBoxed(stack);
  }
  if (!should_wrap_outputs) {
    return;
  }

  // After the call, the op's returns occupy the last `num_returns` stack slots.
  const auto num_returns = schema.returns().size();
  const auto returns_begin = stack->size() - num_returns;
  auto returns = torch::jit::last(stack, num_returns);
  for (const auto idx : c10::irange(num_returns)) {
    const auto& ivalue = returns[idx];
    if (ivalue.isTensor()) {
      const auto& t = ivalue.toTensor();
      if (!t.defined()) continue;
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isTensorList()) {
      auto tensors = ivalue.toTensorList();
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isOptionalTensorList()) {
      auto opt_tensors = ivalue.toOptionalTensorList();
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
      (*stack)[returns_begin + idx] = t_new;
    }
  }
}
}
// resize_() is special because:
// - when we resize to a larger size, it acts as a mutation
// - when we resize to a smaller size, it acts as a view
// See Note [resize_ in Functionalization] for more details
static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format) {
  // Case 1: `self` is not a functional tensor. There is nothing to
  // functionalize, so run the in-place resize_() below the Functionalize key.
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    at::AutoDispatchSkipFunctionalize guard;
    self.resize_(size, memory_format);
    return self;
  }

  // Case 2: actually functionalize resize_().
  // Bring the wrapper up to date, then run the out-of-place at::resize on the
  // unwrapped tensor.
  at::functionalization::impl::sync(self);
  const at::Tensor unwrapped = at::functionalization::impl::from_functional_tensor(self);
  at::Tensor resized;
  {
    at::AutoDispatchSkipFunctionalize guard;
    resized = at::resize(unwrapped, size, memory_format);
  }

  // Bytes a contiguous tensor of the new `size` needs, starting at self's offset.
  const auto required_nbytes = at::detail::computeStorageNbytesContiguous(
      size, self.dtype().itemsize(), self.storage_offset());
  if (required_nbytes > self.storage().nbytes()) {
    // Growing the storage: resize_() behaves as a mutation, so the
    // FunctionalTensorWrapper must swap in the resized storage.
    // See Note[resize_() in functionalization pass]. Per that note, `self` is
    // guaranteed here not to be a view (and to have no outstanding views), so
    // the result does not need to be treated as a view tensor.
    at::functionalization::impl::unsafeGetFunctionalWrapper(self)->maybe_replace_storage(resized);
    return self;
  }

  // Otherwise the storage is large enough and resize_() behaves like a view:
  // the result is a contiguous slice of the existing storage. That slice is
  // emulated with as_strided, and as_strided_scatter is its inverse.
  const bool reapply = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  auto forward = [reapply, sizes = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
    return reapply
        ? base.as_strided(sizes, c10::contiguous_strides(sizes))
        : at::as_strided_copy(base, sizes, c10::contiguous_strides(sizes));
  };
  auto reverse = [sizes = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
    return base.as_strided_scatter(mutated_view, sizes, c10::contiguous_strides(sizes));
  };
  at::functionalization::impl::mutate_view_meta(
      self, at::functionalization::ViewMeta(std::move(forward), std::move(reverse)));
  return self;
}
// Functionalize kernel for aten::lift. The input must not already be a
// functional tensor. It runs at::lift below the Functionalize key and wraps the
// result in a functional tensor.
static at::Tensor lift_functionalize(const at::Tensor & self) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
  at::AutoDispatchSkipFunctionalize guard;
  return at::functionalization::impl::to_functional_tensor(at::lift(self));
}
// Functionalize kernel for aten::lift_fresh. The input must not already be a
// functional tensor. It runs at::lift_fresh below the Functionalize key and
// wraps the result in a functional tensor.
static at::Tensor lift_fresh_functionalize(const at::Tensor & self) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
  at::AutoDispatchSkipFunctionalize guard;
  return at::functionalization::impl::to_functional_tensor(at::lift_fresh(self));
}
// Functionalize kernel for aten::lift_fresh_copy. The input must not already be
// a functional tensor. It runs at::lift_fresh_copy below the Functionalize key
// and wraps the result in a functional tensor.
static at::Tensor lift_fresh_functionalize_copy(const at::Tensor & self) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
  at::AutoDispatchSkipFunctionalize guard;
  return at::functionalization::impl::to_functional_tensor(at::lift_fresh_copy(self));
}
// Returns true when the device the output will live on is XLA or Lazy.
// An empty `tgt_device` means the output stays on the input's device (`self_device`).
static bool device_opted_into_functionalization(c10::Device self_device, c10::optional<c10::Device> tgt_device) {
  const c10::DeviceType out_type = tgt_device.value_or(self_device).type();
  switch (out_type) {
    case c10::DeviceType::XLA:
    case c10::DeviceType::Lazy:
      return true;
    default:
      return false;
  }
}
// Functionalize kernel for aten::_to_copy.
// Per the original author, this kernel is only needed because the
// to.dtype / to.dtype_layout overloads call into _to_copy, and the author noted
// it should probably be removed eventually.
static at::Tensor _to_copy_functionalize(
    const at::Tensor & self,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory,
    bool non_blocking,
    c10::optional<at::MemoryFormat> memory_format) {
  const bool self_is_functional = at::functionalization::impl::isFunctionalTensor(self);
  if (self_is_functional) {
    // Flush any pending updates before reading the wrapped value.
    at::functionalization::impl::sync(self);
  }
  // The backend receives the unwrapped tensor, or `self` when it was never wrapped.
  const at::Tensor self_ = self_is_functional
      ? at::functionalization::impl::from_functional_tensor(self)
      : self;

  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::_to_copy(self_, dtype, layout, device, pin_memory, non_blocking, memory_format);

  // Special case: if the Functionalize key is not in the TLS include set, we
  // assume we're running on a lazy backend (LTC). When copying to a device that
  // has not opted into functionalization, the pass "ends" here: the input has
  // already been synced above, and the output is returned unwrapped.
  const bool functionalize_in_tls =
      c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize);
  if (!functionalize_in_tls && !device_opted_into_functionalization(self.device(), device)) {
    return out;
  }
  return at::functionalization::impl::to_functional_tensor(out);
}
// Why is _unsafe_view special-cased here?
// Basically just to satisfy autograd's debug asserts.
// The situation:
// - _unsafe_view's autograd kernel has debug asserts to confirm
//   that the input and output alias storage.
// - _unsafe_view's schema in native_functions.yaml
//   does not contain alias annotations, so it advertises as non-aliasing.
// - functionalization will then treat _unsafe_view like a non-aliasing op.
//   Specifically, autograd will redispatch to functionalization's
//   boxed fallback kernel, which creates a new FunctionalTensorWrapper output
//   that does **not** alias storage with the input, tripping the assert.
// The kernel written here just manually reifies the aliasing relationship.
//
// Another way to handle this would be to fix unsafe_view's alias annotations
// in native_functions.yaml, but I think this would be a pessimization.
// The idea with _unsafe_view is that you're guaranteed that the input
// is a temporary, and don't actually have to worry about propagating
// mutations between the input and output.
static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymIntArrayRef size) {
  // Non-functional input: nothing to wrap, just redispatch below Functionalize.
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    at::AutoDispatchSkipFunctionalize guard;
    return at::_unsafe_view_symint(self, size);
  }

  auto self_ = at::functionalization::impl::from_functional_tensor(self);
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::_unsafe_view_symint(self_, size);
  }

  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    // Forward: replay the view on a base with the requested size.
    [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
      return at::_unsafe_view_symint(base, size);
    },
    // Reverse: reshape the mutated view back to the base's sizes. This only
    // needs `base`, so it captures nothing (no per-ViewMeta copy of `size`).
    [](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
      return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
    }
  );

  auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
  // See Note [Propagating strides in the functionalization pass]
  // (for _unsafe_view, I'm just manually doing the shape inference rule here instead of calling the meta function for unsafe_view)
  auto inferred_size = at::infer_size_dv(size, self.sym_numel());
  auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size);
  TORCH_INTERNAL_ASSERT(stride.has_value());
  out.unsafeGetTensorImpl()->set_sizes_and_strides(inferred_size, stride.value());
  return out;
}
// Register functionalizeFallback as the boxed fallback kernel for the
// Functionalize dispatch key, for operators in every namespace ("_").
// Operators that register their own Functionalize kernel (such as the aten
// registrations that follow) use that kernel instead of this fallback.
TORCH_LIBRARY_IMPL(_, Functionalize, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>());
}
// Hand-written Functionalize kernels for specific aten operators, registered
// in place of the boxed fallback for these ops.
TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
  // Lifting plain tensors into functional tensors.
  m.impl("lift", TORCH_FN(lift_functionalize));
  m.impl("lift_fresh", TORCH_FN(lift_fresh_functionalize));
  m.impl("lift_fresh_copy", TORCH_FN(lift_fresh_functionalize_copy));
  // Ops whose aliasing/mutation behavior needs special handling.
  m.impl("resize_", TORCH_FN(resize__functionalization));
  m.impl("_unsafe_view", TORCH_FN(_unsafe_view_functionalize));
  // Copies that may move data to another device or dtype.
  m.impl("_to_copy", TORCH_FN(_to_copy_functionalize));
}