"""
Python implementation of __torch_function__
While most of the torch API and handling for __torch_function__ happens
at the C++ level, some of the torch API is written in Python so we need
python-level handling for __torch_function__ overrides as well. The main
developer-facing functionality in this file is the
torch_function_dispatch decorator. This function can be applied to
python functions in the torch.functional module to enable
__torch_function__ overrides for that function. See the examples in the
docstrings for torch_function_dispatch for details.
NOTE: heavily inspired by NumPy's ``__array_function__`` (see:
https://github.com/pytorch/pytorch/issues/24015 and
https://www.numpy.org/neps/nep-0018-array-function-protocol.html
)
"""
import functools
import textwrap

from . import _six

if _six.PY3:
    from inspect import getfullargspec
    import collections

    ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')

    def getargspec(func):
        spec = getfullargspec(func)
        return ArgSpec(spec.args, spec.varargs, spec.varkw, spec.defaults)
else:
    from inspect import getargspec

from .tensor import Tensor

_TENSOR_ONLY = [Tensor]

def _get_overloaded_types_and_args(relevant_args):
    """Returns a list of arguments on which to call __torch_function__.

    Checks arguments in relevant_args for __torch_function__ implementations,
    storing references to the arguments and their types in overloaded_args and
    overloaded_types in order of calling precedence. Only distinct types are
    considered. If a type is a subclass of another type it will have higher
    precedence, otherwise the precedence order is the same as the order of
    arguments in relevant_args, that is, from left-to-right in the argument
    list.

    The precedence-determining algorithm implemented in this function is
    described in `NEP-0018`_.

    See torch::append_overloaded_arg for the equivalent function in the C++
    implementation.

    Parameters
    ----------
    relevant_args : iterable of tensor-like
        Iterable of tensor-like arguments to check for __torch_function__
        methods.

    Returns
    -------
    overloaded_types : collection of types
        Types of arguments from relevant_args with __torch_function__ methods.
    overloaded_args : list
        Arguments from relevant_args on which to call __torch_function__
        methods, in the order in which they should be called.

    .. _NEP-0018:
       https://numpy.org/neps/nep-0018-array-function-protocol.html
    """
    # Runtime is O(num_arguments * num_unique_types)
    overloaded_types = []
    overloaded_args = []
    for arg in relevant_args:
        arg_type = type(arg)
        # We only collect arguments if they have a unique type, which ensures
        # reasonable performance even with a long list of possibly overloaded
        # arguments.
        if (arg_type not in overloaded_types and
                hasattr(arg_type, '__torch_function__')):
            # Create lists explicitly for the first type (usually the only one
            # done) to avoid setting up the iterator for overloaded_args.
            if overloaded_types:
                overloaded_types.append(arg_type)

                # By default, insert argument at the end, but if it is
                # subclass of another argument, insert it before that argument.
                # This ensures "subclasses before superclasses".
                index = len(overloaded_args)
                for i, old_arg in enumerate(overloaded_args):
                    if issubclass(arg_type, type(old_arg)):
                        index = i
                        break
                overloaded_args.insert(index, arg)
            else:
                overloaded_types = [arg_type]
                overloaded_args = [arg]

    return overloaded_types, overloaded_args
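
# Illustration only (hypothetical classes, not part of this module): assuming
# two tensor-likes ``MyTensor`` and ``MySubTensor(MyTensor)`` that both define
# __torch_function__, a call such as
#
#     _get_overloaded_types_and_args([MyTensor(), MySubTensor()])
#
# places the MySubTensor instance *before* the MyTensor instance in
# overloaded_args, because subclasses take precedence over their superclasses;
# unrelated overloaded types keep their left-to-right argument order.
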

def _implement_torch_function(
        implementation, public_api, relevant_args, args, kwargs):
    """Implement a function with checks for __torch_function__ overrides.

    See torch::autograd::handle_torch_function for the equivalent of this
    function in the C++ implementation.

    Arguments
    ---------
    implementation : function
        Function that implements the operation on ``torch.Tensor`` without
        overrides when called like ``implementation(*args, **kwargs)``.
    public_api : function
        Function exposed by the public torch API originally called like
        ``public_api(*args, **kwargs)`` on which arguments are now being
        checked.
    relevant_args : iterable
        Iterable of arguments to check for __torch_function__ methods.
    args : tuple
        Arbitrary positional arguments originally passed into ``public_api``.
    kwargs : dict
        Arbitrary keyword arguments originally passed into ``public_api``.

    Returns
    -------
    Result from calling ``implementation()`` or an ``__torch_function__``
    method, as appropriate.

    Raises
    ------
    TypeError : if no implementation is found.
    """
    # Check for __torch_function__ methods.
    types, overloaded_args = _get_overloaded_types_and_args(relevant_args)

    # Short-cut for common cases: no overload or only Tensor overload
    # (directly or with subclasses that do not override __torch_function__).
    if not overloaded_args or types == _TENSOR_ONLY:
        return implementation(*args, **kwargs)

    # Call overrides
    for overloaded_arg in overloaded_args:
        # Use `public_api` instead of `implementation` so __torch_function__
        # implementations can do equality/identity comparisons.
        result = overloaded_arg.__torch_function__(public_api, args, kwargs)

        if result is not NotImplemented:
            return result

    func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
    raise TypeError("no implementation found for '{}' on types that implement "
                    '__torch_function__: {}'
                    .format(func_name, list(map(type, overloaded_args))))
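
# Sketch (illustrative, not part of this module): a tensor-like class that
# wants to participate in this dispatch defines a method matching the call
# made above, i.e. ``__torch_function__(self, func, args, kwargs)``, and
# returns either a result or NotImplemented. ``HANDLED_FUNCTIONS`` below is a
# hypothetical registry used only for this example:
#
#     class MyTensorLike(object):
#         def __torch_function__(self, func, args, kwargs):
#             if func not in HANDLED_FUNCTIONS:
#                 return NotImplemented
#             return HANDLED_FUNCTIONS[func](*args, **kwargs)
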

def _verify_matching_signatures(implementation, dispatcher):
    """Verify that a dispatcher function has the right signature."""
    implementation_spec = getargspec(implementation)
    dispatcher_spec = getargspec(dispatcher)

    if (implementation_spec.args != dispatcher_spec.args or
            implementation_spec.varargs != dispatcher_spec.varargs or
            implementation_spec.keywords != dispatcher_spec.keywords or
            (bool(implementation_spec.defaults) !=
             bool(dispatcher_spec.defaults)) or
            (implementation_spec.defaults is not None and
             len(implementation_spec.defaults) !=
             len(dispatcher_spec.defaults))):
        raise RuntimeError('implementation and dispatcher for %s have '
                           'different function signatures' % implementation)

_wrapped_func_source = textwrap.dedent("""
    @functools.wraps(implementation)
    def {name}(*args, **kwargs):
        relevant_args = dispatcher(*args, **kwargs)
        return implement_torch_function(
            implementation, {name}, relevant_args, args, kwargs)
    """)

def torch_function_dispatch(dispatcher, module=None, verify=True):
    """Decorator for adding dispatch with the __torch_function__ protocol.

    If you define a function in Python and would like to permit user-defined
    tensor-like types to override it using __torch_function__, please apply
    this decorator to the function together with a custom dispatcher that
    indicates which arguments should be checked for the presence of
    __torch_function__.

    Suppose we'd like to apply this decorator to torch.frob, which has the
    following definition:

        def frob(input, bias, option=None):
            return input + bias

    We'd need to define a dispatcher for frob that has the same signature and
    returns the elements of the signature that should be checked for
    ``__torch_function__``. If any of the arguments has a ``__torch_function__``
    attribute, that function will be called to handle custom dispatch. Assuming
    that ``bias`` can be a tensor-like, our dispatcher would look like:

        def _frob_dispatcher(input, bias, option=None):
            return (input, bias)

    The dispatcher must return an iterable, so return a single-element tuple if
    only one argument should be checked. We would then modify the original
    definition for torch.frob to look like:

        @torch_function_dispatch(_frob_dispatcher)
        def frob(input, bias, option=None):
            return input + bias

    See ``torch/functional.py`` for more usage examples.

    Parameters
    ----------
    dispatcher : callable
        Function that when called like ``dispatcher(*args, **kwargs)`` with
        arguments from the torch function call returns an iterable of
        tensor-like arguments to check for ``__torch_function__``.
    module : str, optional
        ``__module__`` attribute to set on new function, e.g.,
        ``module='torch'``. By default, module is copied from the decorated
        function.
    verify : bool, optional
        If True, verify that the signatures of the dispatcher and the decorated
        function match exactly: all required and optional arguments should
        appear in order with the same names, but the default values for all
        optional arguments should be ``None``. Only disable verification if the
        dispatcher's signature needs to deviate for some particular reason,
        e.g., because the function has a signature like
        ``func(*args, **kwargs)``.

    Returns
    -------
    decorator : callable
        Function suitable for decorating the implementation of a torch
        function.

    Notes
    -----
    The dispatcher should normally return a tuple containing all input
    arguments that may have a ``__torch_function__`` attribute.

    In some cases where that's not easily possible, e.g. ``torch.cat``, it is
    also valid (if a little slower) to make the dispatcher function a generator
    (i.e. use ``yield`` to return arguments one by one).
    """
    def decorator(implementation):
        if verify:
            _verify_matching_signatures(implementation, dispatcher)

        # Equivalently, we could define this function directly instead of using
        # exec. This version has the advantage of giving the helper function a
        # more interpretable name. Otherwise, the original function does not
        # show up at all in many cases, e.g., if it's written in C++ or if the
        # dispatcher gets an invalid keyword argument.
        source = _wrapped_func_source.format(name=implementation.__name__)

        source_object = compile(
            source, filename='<__torch_function__ internals>', mode='exec')
        scope = {
            'implementation': implementation,
            'dispatcher': dispatcher,
            'functools': functools,
            'implement_torch_function': _implement_torch_function,
        }
        _six.exec_(source_object, scope)

        public_api = scope[implementation.__name__]

        if module is not None:
            public_api.__module__ = module

        public_api._implementation = implementation

        return public_api

    return decorator
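
# Illustrative usage sketch (hypothetical names, not part of the torch API).
# It mirrors the ``frob`` example from the docstring above and shows how a
# dispatcher, a dispatched function, and a tensor-like class fit together.
# Guarded so it only runs when this module is executed directly.
if __name__ == '__main__':
    def _frob_dispatcher(input, bias, option=None):
        return (input, bias)

    @torch_function_dispatch(_frob_dispatcher)
    def frob(input, bias, option=None):
        return input + bias

    class LoggingTensor(object):
        """Hypothetical tensor-like that intercepts dispatched calls."""
        def __torch_function__(self, func, args, kwargs):
            if func is frob:
                return 'frob intercepted by LoggingTensor'
            return NotImplemented

    print(frob(1, 2))                # -> 3; no overloaded arguments
    print(frob(1, LoggingTensor()))  # -> 'frob intercepted by LoggingTensor'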