#include <torch/csrc/autograd/input_metadata.h>

// TODO: we may be able to move some imports from input_metadata.h to here, but
// it seems that function.h transitively depends on some of them.

namespace torch::autograd {

namespace {
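
// For C++ NestedTensors (nested and not python-dispatched), the shape is
// recorded as the nested-size Tensor returned by _nested_tensor_size(); all
// other tensors store their symbolic sizes directly.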
MetadataShape compute_variant_shape(const at::Tensor& input) {
  if (input.is_nested() && !input.unsafeGetTensorImpl()->is_python_dispatch()) {
    auto nested_size = input._nested_tensor_size();
    return MetadataShape{std::in_place_type<at::Tensor>, nested_size};
  }
  return MetadataShape{std::in_place_type<SymIntSmallVec>, input.sym_sizes()};
}

bool is_python_dispatch(const at::Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->is_python_dispatch();
}

bool is_cpp_nested_tensor(const at::Tensor& tensor) {
  return tensor.is_nested() && !is_python_dispatch(tensor);
}

} // namespace
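
// Records the properties autograd later uses to validate incoming gradients:
// tensor options (dtype/device/layout), the shape variant computed above,
// subclass/nestedness flags, and the current stream on the input's device.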
InputMetadata::InputMetadata(
    const at::TensorOptions& options,
    MetadataShape input_shape,
    bool is_tensor_subclass,
    bool is_nested)
    : options_{options},
      shape_{std::move(input_shape)},
      is_tensor_subclass_{is_tensor_subclass},
      is_nested_{is_nested},
      was_default_constructed_{false} {
  const auto device = options.device();
  stream_ = c10::impl::getDeviceGuardImpl(device.type())->getStream(device);
}

InputMetadata::InputMetadata(const at::Tensor& t)
    : InputMetadata(
          t.options(),
          compute_variant_shape(t),
          is_python_dispatch(t),
          t.is_nested()) {}

at::Tensor InputMetadata::zeros_like() const {
  TORCH_CHECK(
      !is_nested_, "Zeros is not currently supported for nested tensors.");
  return at::zeros_symint(shape_as_dim_vector(), options_);
}
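
// Given the gradient `grad` flowing into input `i`, returns it unchanged when
// its shape already matches this input, sums it down to the input's shape
// when the grad was produced by broadcasting the input, and otherwise raises
// an error built via `format_error`.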
at::Tensor InputMetadata::maybe_reduce(
    const size_t i,
    at::Tensor grad,
    const std::function<std::string(const std::string&)>& format_error) const {
  auto fail = [&]() {
    const auto message = incompatible_shape_error_message(i, grad);
    TORCH_CHECK(false, format_error(message.str()));
  };

  // Nested tensors are handled with separate, hard-coded logic here, at the
  // risk of some code duplication. Note that this path does NOT apply the
  // careful size-oblivious reasoning used below.
  if (is_nested_ || is_cpp_nested_tensor() || grad.is_nested() ||
      ::torch::autograd::is_cpp_nested_tensor(grad)) {
    if (!is_same_shape(grad)) {
      if (is_expandable_to_shape(grad)) {
        return reduce_grad(grad);
      } else {
        fail();
      }
    } else {
      return grad;
    }
  }

  auto shape = shape_as_dim_vector();
  auto desired = grad.sym_sizes();

  size_t ndim = shape.size();
  size_t target_dim = desired.size();
  if (ndim > target_dim) {
    fail();
  }
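
  // Walk the trailing dimensions of both shapes; a size-1 input dim that does
  // not definitely equal the corresponding grad dim means the grad was
  // broadcast there and must be summed back. Illustrative example: input
  // shape [1, 3] with grad shape [2, 3] flags dim 0 for reduction, and
  // reduce_grad then sums the grad down to [1, 3].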
  bool needs_reduce = false;
  for (const auto i : c10::irange(ndim)) {
    const auto& size = shape[ndim - i - 1];
    const auto& target = desired[target_dim - i - 1];
    // The conditions here are written carefully so that we are able to
    // infer deferred runtime asserts
    if (TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))) {
      // NB: we could short circuit this once needs_reduce is true but there's
      // no point since the reduction function will guard on this anyway
      if (!c10::definitely_true(size.sym_eq(target), __FILE__, __LINE__)) {
        needs_reduce = true;
      }
    } else {
      if (!size.sym_eq(target).expect_true(__FILE__, __LINE__)) {
        fail();
      }
    }
  }
  if (ndim != target_dim) {
    needs_reduce = true;
  }

  if (needs_reduce) {
    return reduce_grad(grad);
  } else {
    return grad;
  }
}

bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
  if (!is_nestedness_same(grad)) {
    return false;
  }
  if (is_cpp_nested_tensor()) {
    return grad._nested_tensor_size().is_same_size(shape_as_tensor());
  }
  return grad.sym_sizes().equals(shape_as_dim_vector());
}
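
// Note the direction: this asks whether the *input's* recorded shape can be
// expanded (broadcast) to the grad's shape, which is the precondition for
// summing the grad back down in reduce_grad.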
bool InputMetadata::is_expandable_to_shape(const at::Tensor& grad) const {
  if (!maybe_expandable_to(grad)) {
    return false;
  }
  return at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
}

at::Tensor InputMetadata::reduce_grad(at::Tensor& grad) const {
  // reduce_grad should only be called if is_expandable_to_shape returns true.
  TORCH_INTERNAL_ASSERT(maybe_expandable_to(grad));
  return at::sum_to(std::move(grad), shape_as_dim_vector());
}

std::stringstream InputMetadata::incompatible_shape_error_message(
    const size_t index,
    const at::Tensor& grad) const {
  std::stringstream ss{};
  ss << "invalid gradient at index " << index << " - got ";
  if (::torch::autograd::is_cpp_nested_tensor(grad)) {
    ss << grad._nested_tensor_size();
  } else {
    ss << grad.sym_sizes();
  }
  ss << " but expected shape compatible with ";
  if (is_cpp_nested_tensor()) {
    ss << shape_as_tensor();
  } else {
    ss << shape_as_dim_vector();
  }
  return ss;
}

bool InputMetadata::is_cpp_nested_tensor() const {
  bool ret = std::holds_alternative<at::Tensor>(shape_);
  TORCH_INTERNAL_ASSERT(ret == (is_nested_ && !is_tensor_subclass_));
  return ret;
}
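
// The shape accessors below use std::get, which enforces that the expected
// variant alternative is held (SymIntSmallVec for ordinary tensors, an
// at::Tensor of nested sizes for C++ nested tensors) and throws
// std::bad_variant_access otherwise.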
c10::SymIntArrayRef InputMetadata::shape_as_dim_vector() const {
  const auto& dim_shape = std::get<SymIntSmallVec>(shape_);
  return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size());
}

// Danger: not thread safe, caller must protect with lock
SymIntSmallVec& InputMetadata::mutable_shape_as_dim_vector() {
  return std::get<SymIntSmallVec>(shape_);
}

bool InputMetadata::is_nestedness_same(const at::Tensor& grad) const {
  return (
      grad.is_nested() == is_nested_ &&
      ::torch::autograd::is_cpp_nested_tensor(grad) == is_cpp_nested_tensor());
}

at::Tensor InputMetadata::shape_as_tensor() const {
  return std::get<at::Tensor>(shape_);
}

bool InputMetadata::maybe_expandable_to(const at::Tensor& grad) const {
  // This is the initial step in determining whether the tensor represented by
  // input_metadata is expandable to grad, based on nestedness information
  // alone. If this returns true, is_expandable_to_shape then checks the
  // actual sizes. We support the following 3 types of expansion:
  bool grad_is_nested = grad.is_nested();
  if (!is_nested_ && !grad_is_nested) {
    // Normal case (no NestedTensors are involved)
    // (1) plain Tensor -> plain Tensor
    return true;
  } else {
    // (2) python NT -> python NT
    // (3) plain Tensor -> python NT
    return (
        grad_is_nested && is_python_dispatch(grad) &&
        (!is_nested_ || is_tensor_subclass_));
  }
}
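
// Illustrative usage sketch (assumed call pattern, not part of this file):
// the autograd engine records InputMetadata when a Node captures its inputs,
// then validates/reduces each incoming gradient against it, roughly:
//
//   at::Tensor input = ...;       // a forward input
//   InputMetadata meta{input};    // captured at graph-construction time
//   at::Tensor fixed = meta.maybe_reduce(
//       /*i=*/0, std::move(grad), [](const std::string& s) { return s; });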

} // namespace torch::autograd