TorchScript
===========
.. toctree::
    :maxdepth: 1
    :caption: Builtin Functions
    :hidden:

    torch.jit.supported_ops <jit_builtin_functions>

.. contents:: :local:

.. automodule:: torch.jit

.. currentmodule:: torch.jit
TorchScript is a way to create serializable and optimizable models from PyTorch code.
Any TorchScript program can be saved from a Python
process and loaded in a process where there is no Python dependency.

We provide tools to incrementally transition a model from a pure Python program
to a TorchScript program that can be run independently from Python, such as in a standalone C++ program.
This makes it possible to train models in PyTorch using familiar tools in Python and then export
the model via TorchScript to a production environment where Python programs may be disadvantageous
for performance and multi-threading reasons.

For a gentle introduction to TorchScript, see the `Introduction to TorchScript <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`_ tutorial.

For an end-to-end example of converting a PyTorch model to TorchScript and running it in C++, see the
`Loading a PyTorch Model in C++ <https://pytorch.org/tutorials/advanced/cpp_export.html>`_ tutorial.
Creating TorchScript Code
--------------------------

.. autoclass:: ScriptModule()
    :members:

.. autoclass:: ScriptFunction()

.. autofunction:: script(obj)

.. autofunction:: trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5)

.. autofunction:: trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5)

.. autofunction:: save

.. autofunction:: load
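The save/load round trip described in the introduction can be sketched as follows. This is a minimal example that assumes a toy module named ``Scale``, which exists only for illustration. It scripts the module, writes it with :func:`save` to an in-memory ``io.BytesIO`` buffer (a file path behaves the same way), and reads it back with :func:`load`.

```python
import io

import torch


class Scale(torch.nn.Module):
    # Hypothetical toy module: multiplies its input by a fixed factor.
    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor


scripted = torch.jit.script(Scale(2.0))

# Serialize into an in-memory buffer; a path such as "scale.pt" also works.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)  # rewind so that load() reads from the start

loaded = torch.jit.load(buffer)
print(loaded(torch.ones(3)))  # tensor([2., 2., 2.])
```

The loaded module carries its compiled code with it, so calling it does not require the ``Scale`` class definition.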
Mixing Tracing and Scripting
----------------------------
In many cases either tracing or scripting is an easier approach for converting a model to TorchScript.
Tracing and scripting can be composed to suit the particular requirements
of a part of a model.
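The two approaches differ most on data-dependent control flow. The minimal sketch below uses ``double_if_positive``, a hypothetical function written only for illustration. Scripting compiles both branches of the ``if``. Tracing records only the operations that ran for the example input, so the branch taken at trace time is baked into the result.

```python
import warnings

import torch


def double_if_positive(x):
    # Data-dependent control flow: the branch depends on the values in `x`.
    if x.sum() > 0:
        return x * 2
    return torch.zeros_like(x)


# Scripting compiles both branches of the `if`.
scripted = torch.jit.script(double_if_positive)

# Tracing with a positive input records only the `x * 2` path; PyTorch
# emits a TracerWarning about the tensor-to-bool conversion, silenced here.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(double_if_positive, torch.ones(2))

negative = -torch.ones(2)
print(scripted(negative))  # tensor([0., 0.]): the other branch is still compiled
print(traced(negative))    # tensor([-2., -2.]): the recorded path always runs
```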
Scripted functions can call traced functions. This is particularly useful when you need
to use control-flow around a simple feed-forward model. For instance, the beam search
of a sequence-to-sequence model will typically be written in script but can call an
encoder module generated using tracing.
.. testsetup::

    # These are hidden from the docs, but these are necessary for `doctest`
    # since the `inspect` module doesn't play nicely with the execution
    # environment for `doctest`
    import torch

    original_script = torch.jit.script
    def script_wrapper(obj, *args, **kwargs):
        obj.__module__ = 'FakeMod'
        return original_script(obj, *args, **kwargs)

    torch.jit.script = script_wrapper

    original_trace = torch.jit.trace
    def trace_wrapper(obj, *args, **kwargs):
        obj.__module__ = 'FakeMod'
        return original_trace(obj, *args, **kwargs)

    torch.jit.trace = trace_wrapper
Example (calling a traced function in script):

.. testcode::

    import torch

    def foo(x, y):
        return 2 * x + y

    traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

    @torch.jit.script
    def bar(x):
        return traced_foo(x, x)
Traced functions can call script functions. This is useful when a small part of
a model requires some control-flow even though most of the model is just a feed-forward
network. Control-flow inside of a script function called by a traced function is
preserved correctly.

Example (calling a script function in a traced function):

.. testcode::

    import torch

    @torch.jit.script
    def foo(x, y):
        if x.max() > y.max():
            r = x
        else:
            r = y
        return r


    def bar(x, y, z):
        return foo(x, y) + z

    traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))
This composition also works for ``nn.Module``\s, where it can be used to generate
a submodule using tracing that can be called from the methods of a script module.

Example (using a traced module):

.. testcode::
    :skipif: torchvision is None

    import torch
    import torchvision

    class MyScriptModule(torch.nn.Module):
        def __init__(self):
            super(MyScriptModule, self).__init__()
            self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                            .resize_(1, 3, 1, 1))
            self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                          torch.rand(1, 3, 224, 224))

        def forward(self, input):
            return self.resnet(input - self.means)

    my_script_module = torch.jit.script(MyScriptModule())
TorchScript Language Reference
-------------------------------

TorchScript is a statically typed subset of Python that can either be written directly (using
the :func:`@torch.jit.script <torch.jit.script>` decorator) or generated automatically from Python code via
tracing. When using tracing, code is automatically converted into this subset of
Python by recording only the actual operators on tensors and simply executing and
discarding the other surrounding Python code.

When writing TorchScript directly using the ``@torch.jit.script`` decorator, the programmer must
only use the subset of Python supported in TorchScript. This section documents
what is supported in TorchScript as if it were a language reference for a
standalone language. Any features of Python not mentioned in this reference are not
part of TorchScript. See `Builtin Functions`_ for a complete reference of available
PyTorch tensor methods, modules, and functions.

As a subset of Python, any valid TorchScript function is also a valid Python
function. This makes it possible to `disable TorchScript`_ and debug the
function using standard Python tools like ``pdb``. The reverse is not true: there
are many valid Python programs that are not valid TorchScript programs.
Instead, TorchScript focuses specifically on the features of Python that are
needed to represent neural network models in PyTorch.
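The sketch below illustrates that asymmetry with ``add_all``, a hypothetical helper that takes a variable number of positional arguments. It runs fine as ordinary Python, but ``torch.jit.script`` rejects it because compiled functions cannot take ``*args``. The sketch catches a generic ``Exception`` because the exact exception class is an implementation detail.

```python
import torch


def add_all(*tensors):
    # Valid Python: accepts any number of positional arguments.
    result = torch.zeros(1)
    for t in tensors:
        result = result + t
    return result


print(add_all(torch.ones(1), torch.ones(1)))  # plain Python call: tensor([2.])

scripting_failed = False
try:
    torch.jit.script(add_all)
except Exception as err:  # exact exception class is an implementation detail
    scripting_failed = True
    print("not valid TorchScript:", type(err).__name__)
```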
.. _types:
.. _supported type:

Types
~~~~~

The largest difference between TorchScript and the full Python language is that
TorchScript only supports a small set of types that are needed to express neural
net models. In particular, TorchScript supports:

.. csv-table::
    :header: "Type", "Description"

    "``Tensor``", "A PyTorch tensor of any dtype, dimension, or backend"
    "``Tuple[T0, T1, ...]``", "A tuple containing subtypes ``T0``, ``T1``, etc. (e.g. ``Tuple[Tensor, Tensor]``)"
    "``bool``", "A boolean value"
    "``int``", "A scalar integer"
    "``float``", "A scalar floating point number"
    "``str``", "A string"
    "``List[T]``", "A list of which all members are type ``T``"
    "``Optional[T]``", "A value which is either None or type ``T``"
    "``Dict[K, V]``", "A dict with key type ``K`` and value type ``V``. Only ``str``, ``int``, and ``float`` are allowed as key types."
    "``T``", "A `TorchScript Class`_"
    "``NamedTuple[T0, T1, ...]``", "A :func:`collections.namedtuple <collections.namedtuple>` tuple type"

Unlike Python, each variable in a TorchScript function must have a single static type.
This makes it easier to optimize TorchScript functions.

Example (a type mismatch):
.. testcode::

    import torch

    @torch.jit.script
    def an_error(x):
        if x:
            r = torch.rand(1)
        else:
            r = 4
        return r

.. testoutput::

     Traceback (most recent call last):
       ...
     RuntimeError: ...

     Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
     @torch.jit.script
     def an_error(x):
         if x:
         ~~~~~... <--- HERE
             r = torch.rand(1)
         else:
     and was used here:
         else:
             r = 4
         return r
                ~ <--- HERE
     ...
Unsupported Typing Constructs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

TorchScript does not support all features and types of the :mod:`typing` module. Some of these
are more fundamental things that are unlikely to be added in the future while others
may be added if there is enough user demand to make it a priority.

These types and features from the :mod:`typing` module are unavailable in TorchScript.

.. csv-table::
    :header: "Item", "Description"

    ":any:`typing.Any`", ":any:`typing.Any` is currently in development but not yet released"
    ":any:`typing.NoReturn`", "Not implemented"
    ":any:`typing.Union`", "Unlikely to be implemented (however :any:`typing.Optional` is supported)"
    ":any:`typing.Callable`", "Not implemented"
    ":any:`typing.Literal`", "Not implemented"
    ":any:`typing.ClassVar`", "Not implemented"
    ":any:`typing.Final`", "Supported for class attribute annotations of :any:`module attributes <Module Attributes>`, but not for functions"
    ":any:`typing.AnyStr`", "TorchScript does not support :any:`bytes` so this type is not used"
    ":any:`typing.overload`", ":any:`typing.overload` is currently in development but not yet released"
    "Type aliases", "Not implemented"
    "Nominal vs structural subtyping", "Nominal typing is in development, but structural typing is not"
    "NewType", "Unlikely to be implemented"
    "Generics", "Unlikely to be implemented"

Any other functionality from the :any:`typing` module not explicitly listed in this documentation is unsupported.
Default Types
^^^^^^^^^^^^^

By default, all parameters to a TorchScript function are assumed to be Tensor.
To specify that an argument to a TorchScript function is another type, it is possible to use
MyPy-style type annotations using the types listed above.

.. testcode::

    import torch

    @torch.jit.script
    def foo(x, tup):
        # type: (int, Tuple[Tensor, Tensor]) -> Tensor
        t0, t1 = tup
        return t0 + t1 + x

    print(foo(3, (torch.rand(3), torch.rand(3))))

.. testoutput::
    :hide:

    ...

.. note::
    It is also possible to annotate types with Python 3 type hints from the
    ``typing`` module.

    .. testcode::

        import torch
        from typing import Tuple

        @torch.jit.script
        def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
            t0, t1 = tup
            return t0 + t1 + x

        print(foo(3, (torch.rand(3), torch.rand(3))))

    .. testoutput::
        :hide:

        ...

    In our examples, we use comment-based type hints to ensure Python 2
    compatibility as well.
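The following minimal sketch shows the default-to-``Tensor`` rule using ``add_one``, a hypothetical function with no annotations. Calling it with a tensor works. Passing a Python ``int`` is rejected when the arguments are checked against the compiled signature.

```python
import torch


@torch.jit.script
def add_one(x):
    # No annotation, so TorchScript types `x` as Tensor.
    return x + 1


print(add_one(torch.tensor([1.0, 2.0])))  # tensor([2., 3.])

rejected = False
try:
    add_one(3)  # a Python int does not match the inferred Tensor type
except Exception as err:  # exact exception class is an implementation detail
    rejected = True
    print("rejected:", err)
```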
An empty list is assumed to be ``List[Tensor]`` and empty dicts
``Dict[str, Tensor]``. To instantiate an empty list or dict of other types,
use `Python 3 type hints`_. If you are on Python 2, you can use ``torch.jit.annotate``.

Example (type annotations for Python 3):

.. testcode::

    import torch
    import torch.nn as nn
    from typing import Dict, List, Tuple

    class EmptyDataStructures(torch.nn.Module):
        def __init__(self):
            super(EmptyDataStructures, self).__init__()

        def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
            # This annotates the list to be a `List[Tuple[int, float]]`
            my_list: List[Tuple[int, float]] = []
            for i in range(10):
                # `item()` returns a generic number, so convert it to `float`
                my_list.append((i, float(x.item())))

            my_dict: Dict[str, int] = {}
            return my_list, my_dict

    x = torch.jit.script(EmptyDataStructures())


Example (``torch.jit.annotate`` for Python 2):

.. testcode::

    import torch
    import torch.nn as nn
    from typing import Dict, List, Tuple

    class EmptyDataStructures(torch.nn.Module):
        def __init__(self):
            super(EmptyDataStructures, self).__init__()

        def forward(self, x):
            # type: (Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]

            # This annotates the list to be a `List[Tuple[int, float]]`
            my_list = torch.jit.annotate(List[Tuple[int, float]], [])
            for i in range(10):
                my_list.append((i, float(x.item())))

            my_dict = torch.jit.annotate(Dict[str, int], {})
            return my_list, my_dict

    x = torch.jit.script(EmptyDataStructures())
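The effect of the ``List[Tensor]`` default is shown in the following sketch. Both function names are hypothetical. Without an annotation, appending an ``int`` to an empty list fails to compile. The annotated version compiles and returns a list of ints.

```python
import torch
from typing import List


def collect_unannotated(n: int):
    xs = []  # no annotation: TorchScript assumes List[Tensor]
    for i in range(n):
        xs.append(i)  # appending an int to a List[Tensor] is a type error
    return xs


@torch.jit.script
def collect_annotated(n: int) -> List[int]:
    xs: List[int] = []  # the annotation overrides the List[Tensor] default
    for i in range(n):
        xs.append(i)
    return xs


compiled_unannotated = True
try:
    torch.jit.script(collect_unannotated)
except Exception:  # exact exception class is an implementation detail
    compiled_unannotated = False

print(compiled_unannotated)   # False
print(collect_annotated(3))   # [0, 1, 2]
```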
Optional Type Refinement
^^^^^^^^^^^^^^^^^^^^^^^^

TorchScript will refine the type of a variable of type ``Optional[T]`` when
a comparison to ``None`` is made inside the conditional of an if-statement or checked in an ``assert``.
The compiler can reason about multiple ``None`` checks that are combined with
``and``, ``or``, and ``not``. Refinement will also occur for else blocks of if-statements
that are not explicitly written.

The ``None`` check must be within the if-statement's condition; assigning
a ``None`` check to a variable and using it in the if-statement's condition will
not refine the types of variables in the check.
Only local variables will be refined. An attribute like ``self.x`` will not be refined
and must be assigned to a local variable first.
Example (refining types on parameters and locals):

.. testcode::

    import torch
    import torch.nn as nn
    from typing import Optional

    class M(nn.Module):
        z: Optional[int]

        def __init__(self, z):
            super(M, self).__init__()
            # If `z` is None, its type cannot be inferred, so it must
            # be specified (above)
            self.z = z

        def forward(self, x, y, z):
            # type: (Optional[int], Optional[int], Optional[int]) -> int
            if x is None:
                x = 1
                x = x + 1

            # Refinement for an attribute by assigning it to a local
            z = self.z
            if y is not None and z is not None:
                x = y + z

            # Refinement via an `assert`
            assert z is not None
            x += z
            return x

    module = torch.jit.script(M(2))
    module = torch.jit.script(M(None))
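The implicit-``else`` and ``or`` cases described above can be sketched with two small hypothetical functions. In each one, the early ``return`` inside the ``None`` check means the code after the ``if`` only runs when the values are not ``None``. The compiler therefore refines them from ``Optional[int]`` to ``int``.

```python
import torch
from typing import Optional


@torch.jit.script
def add_or_default(x: Optional[int], y: int) -> int:
    if x is None:
        return y
    # No explicit `else:`, but because the branch above returns, `x` is
    # refined from Optional[int] to int for the rest of the function.
    return x + y


@torch.jit.script
def product_or_zero(a: Optional[int], b: Optional[int]) -> int:
    # Refinement also works through `or`: past this check, neither is None.
    if a is None or b is None:
        return 0
    return a * b


print(add_or_default(None, 3), add_or_default(2, 3))    # 3 5
print(product_or_zero(None, 4), product_or_zero(2, 4))  # 0 8
```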
.. _TorchScript Class:
.. _TorchScript Classes:

TorchScript Classes
^^^^^^^^^^^^^^^^^^^

Python classes can be used in TorchScript if they are annotated with :func:`@torch.jit.script <torch.jit.script>`,
similar to how you would declare a TorchScript function:

.. testcode::
    :skipif: True  # TODO: fix the source file resolving so this can be tested

    @torch.jit.script
    class Foo:
        def __init__(self, x, y):
            self.x = x

        def aug_add_x(self, inc):
            self.x += inc
This subset is restricted:

* All functions must be valid TorchScript functions (including ``__init__()``).
* Classes must be new-style classes, as we use ``__new__()`` to construct them with pybind11.
* TorchScript classes are statically typed. Members can only be declared by assigning to
  self in the ``__init__()`` method.

  For example, assigning to ``self`` outside of the ``__init__()`` method: ::

      @torch.jit.script
      class Foo:
          def assign_x(self):
              self.x = torch.rand(2, 3)

  This results in: ::

      RuntimeError:
      Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
      def assign_x(self):
          self.x = torch.rand(2, 3)
          ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

* No expressions except method definitions are allowed in the body of the class.
* No support for inheritance or any other polymorphism strategy, except for inheriting
  from ``object`` to specify a new-style class.
After a class is defined, it can be used in both TorchScript and Python interchangeably
like any other TorchScript type:

::

    import torch

    # Declare a TorchScript class
    @torch.jit.script
    class Pair:
        def __init__(self, first, second):
            self.first = first
            self.second = second

    @torch.jit.script
    def sum_pair(p):
        # type: (Pair) -> Tensor
        return p.first + p.second

    p = Pair(torch.rand(2, 3), torch.rand(2, 3))
    print(sum_pair(p))
Named Tuples
^^^^^^^^^^^^

Types produced by :func:`collections.namedtuple <collections.namedtuple>` can be used in TorchScript.

.. testcode::

    import torch
    import collections

    Point = collections.namedtuple('Point', ['x', 'y'])

    @torch.jit.script
    def total(point):
        # type: (Point) -> Tensor
        return point.x + point.y

    p = Point(x=torch.rand(3), y=torch.rand(3))
    print(total(p))

.. testoutput::
    :hide:

    ...

.. _jit_iterables:

Iterables
^^^^^^^^^

Some functions (for example, :any:`zip` and :any:`enumerate`) can only operate on iterable types.
Iterable types in TorchScript include ``Tensor``\s, lists, tuples, dictionaries, strings,
:any:`torch.nn.ModuleList` and :any:`torch.nn.ModuleDict`.
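In the following sketch (function names are hypothetical), :any:`zip` pairs up two lists and :any:`enumerate` tracks indices inside scripted functions.

```python
import torch
from typing import List


@torch.jit.script
def weighted_sum(xs: List[torch.Tensor], weights: List[float]) -> torch.Tensor:
    total = torch.zeros(1)
    # zip pairs up elements of the two lists, stopping at the shorter one
    for x, w in zip(xs, weights):
        total = total + x * w
    return total


@torch.jit.script
def index_of_max(values: List[float]) -> int:
    best_i = 0
    best_v = values[0]
    # enumerate yields (index, element) pairs
    for i, v in enumerate(values):
        if v > best_v:
            best_i = i
            best_v = v
    return best_i


print(weighted_sum([torch.ones(1), torch.full((1,), 2.0)], [0.5, 0.25]))  # tensor([1.])
print(index_of_max([1.0, 3.0, 2.0]))  # 1
```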
Expressions
~~~~~~~~~~~

The following Python expressions are supported.

Literals
^^^^^^^^
::

    True
    False
    None
    'string literals'
    "string literals"
    3  # interpreted as int
    3.4  # interpreted as a float

List Construction
"""""""""""""""""
An empty list is assumed to have type ``List[Tensor]``.
The types of other list literals are derived from the type of the members.
See `Default Types`_ for more details.

::

    [3, 4]
    []
    [torch.rand(3), torch.rand(4)]

Tuple Construction
""""""""""""""""""
::

    (3, 4)
    (3,)

Dict Construction
"""""""""""""""""
An empty dict is assumed to have type ``Dict[str, Tensor]``.
The types of other dict literals are derived from the type of the members.
See `Default Types`_ for more details.

::

    {'hello': 3}
    {}
    {'a': torch.rand(3), 'b': torch.rand(4)}
Variables
^^^^^^^^^
See `Variable Resolution`_ for how variables are resolved.

::

    my_variable_name

Arithmetic Operators
^^^^^^^^^^^^^^^^^^^^
::

    a + b
    a - b
    a * b
    a / b
    a ^ b
    a @ b

Comparison Operators
^^^^^^^^^^^^^^^^^^^^
::

    a == b
    a != b
    a < b
    a > b
    a <= b
    a >= b

Logical Operators
^^^^^^^^^^^^^^^^^
::

    a and b
    a or b
    not b

Subscripts and Slicing
^^^^^^^^^^^^^^^^^^^^^^
::

    t[0]
    t[-1]
    t[0:2]
    t[1:]
    t[:1]
    t[:]
    t[0, 1]
    t[0, 1:2]
    t[0, :1]
    t[-1, 1:, 0]
    t[1:, -1, 0]
    t[i:j, i]
Function Calls
^^^^^^^^^^^^^^
Calls to `builtin functions`_:

::

    torch.zeros(3, dtype=torch.int)

Calls to other script functions:

.. testcode::

    import torch

    @torch.jit.script
    def foo(x):
        return x + 1

    @torch.jit.script
    def bar(x):
        return foo(x)
Method Calls
^^^^^^^^^^^^
Calls to methods of builtin types like tensor: ``x.mm(y)``

On modules, methods must be compiled before they can be called. The TorchScript
compiler recursively compiles methods it sees when compiling other methods. By default,
compilation starts on the ``forward`` method. Any methods called by ``forward`` will
be compiled, and any methods called by those methods, and so on. To start compilation at
a method other than ``forward``, use the :func:`@torch.jit.export <torch.jit.export>` decorator
(``forward`` implicitly is marked ``@torch.jit.export``).

Calling a submodule directly (e.g. ``self.resnet(input)``) is equivalent to
calling its ``forward`` method (e.g. ``self.resnet.forward(input)``).

.. testcode::
    :skipif: torchvision is None

    import torch
    import torch.nn as nn
    import torchvision

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            means = torch.tensor([103.939, 116.779, 123.68])
            self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
            resnet = torchvision.models.resnet18()
            self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))

        def helper(self, input):
            return self.resnet(input - self.means)

        def forward(self, input):
            return self.helper(input)

        # Since nothing in the model calls `top_level_method`, the compiler
        # must be explicitly told to compile this method
        @torch.jit.export
        def top_level_method(self, input):
            return self.other_helper(input)

        def other_helper(self, input):
            return input + 10

    # `my_script_module` will have the compiled methods `forward`, `helper`,
    # `top_level_method`, and `other_helper`
    my_script_module = torch.jit.script(MyModule())
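The same rules can be seen in a smaller sketch that does not need ``torchvision``. It assumes a hypothetical ``Scaler`` module. ``scale_by_two`` is compiled because ``forward`` calls it. ``scale_by_three`` is compiled only because it is marked with ``@torch.jit.export``.

```python
import torch
import torch.nn as nn


class Scaler(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `forward` is compiled by default, and so is everything it calls.
        return self.scale_by_two(x)

    def scale_by_two(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @torch.jit.export
    def scale_by_three(self, x: torch.Tensor) -> torch.Tensor:
        # Not reachable from `forward`; `@torch.jit.export` compiles it anyway.
        return x * 3


scripted = torch.jit.script(Scaler())
x = torch.ones(2)
print(scripted(x))                 # tensor([2., 2.])
print(scripted.scale_by_three(x))  # tensor([3., 3.])
```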
Ternary Expressions
^^^^^^^^^^^^^^^^^^^
::

    x if x > y else y

Casts
^^^^^
::

    float(ten)
    int(3.5)
    bool(ten)
    str(2)
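The casts above can be combined in one scripted function, sketched below under the hypothetical name ``casts``. Converting a tensor with ``float()`` or ``bool()`` assumes the tensor holds exactly one element.

```python
import torch
from typing import Tuple


@torch.jit.script
def casts(t: torch.Tensor) -> Tuple[float, int, bool, str]:
    # `float(t)` and `bool(t)` require a tensor with a single element;
    # `int(3.5)` truncates toward zero, as in Python.
    return float(t), int(3.5), bool(t), str(2)


print(casts(torch.tensor(1.5)))  # (1.5, 3, True, '2')
```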
Accessing Module Parameters
^^^^^^^^^^^^^^^^^^^^^^^^^^^
::

    self.my_parameter
    self.my_submodule.my_parameter
Statements
~~~~~~~~~~

TorchScript supports the following types of statements:

Simple Assignments
^^^^^^^^^^^^^^^^^^
::

    a = b
    a += b # short-hand for a = a + b, does not operate in-place on a
    a -= b

Pattern Matching Assignments
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
::

    a, b = tuple_or_list
    a, b, *c = a_tuple

Multiple Assignments
^^^^^^^^^^^^^^^^^^^^
::

    a = b, c = tup
Print Statements
^^^^^^^^^^^^^^^^
::

    print("the result of an add:", a + b)

If Statements
^^^^^^^^^^^^^
::

    if a < 4:
        r = -a
    elif a < 8:
        r = a + a
    else:
        r = 3 * a

In addition to bools, floats, ints, and Tensors can be used in a conditional
and will be implicitly cast to a boolean.
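The following minimal sketch shows the implicit conversion to ``bool``. The function ``truthy_flags`` is hypothetical. Each ``if`` tests a different type, and a zero value counts as ``False`` just as in Python. The tensor is assumed to hold a single element.

```python
import torch


@torch.jit.script
def truthy_flags(t: torch.Tensor, count: int, scale: float) -> int:
    flags = 0
    if t:       # one-element Tensor: nonzero is True
        flags += 1
    if count:   # int: nonzero is True
        flags += 10
    if scale:   # float: nonzero is True
        flags += 100
    return flags


print(truthy_flags(torch.tensor(1.0), 0, 2.5))  # 101
print(truthy_flags(torch.tensor(0.0), 3, 0.0))  # 10
```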
While Loops
^^^^^^^^^^^
::

    a = 0
    while a < 4:
        print(a)
        a += 1

For loops with range
^^^^^^^^^^^^^^^^^^^^
::

    x = 0
    for i in range(10):
        x *= i
For loops over tuples
^^^^^^^^^^^^^^^^^^^^^
These unroll the loop, generating a body for
each member of the tuple. The body must type-check correctly for each member.

::

    tup = (3, torch.rand(4))
    for x in tup:
        print(x)
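The sketch below shows the type-checking requirement with a hypothetical function that iterates over a ``Tuple[int, Tensor]``. The unrolled body is checked once with ``item`` as an ``int`` and once as a ``Tensor``. It compiles only because ``float()`` accepts both types; the tensor is assumed to hold a single element.

```python
import torch
from typing import Tuple


@torch.jit.script
def sum_as_float(tup: Tuple[int, torch.Tensor]) -> float:
    total = 0.0
    # Unrolled at compile time: the body is emitted once with `item` typed as
    # int and once with `item` typed as Tensor.
    for item in tup:
        total += float(item)
    return total


print(sum_as_float((3, torch.tensor(1.5))))  # 4.5
```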
For loops over constant nn.ModuleList
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
To use a ``nn.ModuleList`` inside a compiled method, it must be marked
constant by adding the name of the attribute to the ``__constants__``
list for the type. For loops over a ``nn.ModuleList`` will unroll the body of the
loop at compile time, with each member of the constant module list.

.. testcode::

    import torch
    import torch.nn as nn

    class SubModule(torch.nn.Module):
        def __init__(self):
            super(SubModule, self).__init__()
            self.weight = nn.Parameter(torch.randn(2))

        def forward(self, input):
            return self.weight + input

    class MyModule(torch.nn.Module):
        __constants__ = ['mods']

        def __init__(self):
            super(MyModule, self).__init__()
            self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

        def forward(self, v):
            for module in self.mods:
                v = module(v)
            return v


    m = torch.jit.script(MyModule())
Break and Continue
^^^^^^^^^^^^^^^^^^
::

    for i in range(5):
        if i == 1:
            continue
        if i == 3:
            break
        print(i)

Return
^^^^^^
::

    return a, b
Variable Resolution
~~~~~~~~~~~~~~~~~~~

TorchScript supports a subset of Python's variable resolution (i.e. scoping)
rules. Local variables behave the same as in Python, except for the restriction
that a variable must have the same type along all paths through a function.
If a variable has a different type on different branches of an if statement, it
is an error to use it after the end of the if statement.

Similarly, a variable is not allowed to be used if it is only *defined* along some
paths through the function.

Example:

.. testcode::

    @torch.jit.script
    def foo(x):
        if x < 0:
            y = 4
        print(y)

.. testoutput::

     Traceback (most recent call last):
       ...
     RuntimeError: ...

     y is not defined in the false branch...
     @torch.jit.script...
     def foo(x):
         if x < 0:
         ~~~~~~~~~... <--- HERE
             y = 4
         print(y)
     ...

Non-local variables are resolved to Python values at compile time when the
function is defined. These values are then converted into TorchScript values using
the rules described in `Use of Python Values`_.
Use of Python Values
~~~~~~~~~~~~~~~~~~~~

To make writing TorchScript more convenient, we allow script code to refer
to Python values in the surrounding scope. For instance, any time there is a
reference to ``torch``, the TorchScript compiler is actually resolving it to the
``torch`` Python module when the function is declared. These Python values are
not a first-class part of TorchScript. Instead, they are de-sugared at compile-time
into the primitive types that TorchScript supports. This depends
on the dynamic type of the Python value referenced when compilation occurs.
This section describes the rules that are used when accessing Python values in TorchScript.
Functions
^^^^^^^^^

TorchScript can call Python functions. This functionality is very useful when
incrementally converting a model to TorchScript. The model can be moved function-by-function
to TorchScript, leaving calls to Python functions in place. This way you can incrementally
check the correctness of the model as you go.

.. autofunction:: ignore

.. autofunction:: unused

.. autofunction:: is_scripting
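The following minimal sketch uses :func:`ignore` and :func:`is_scripting` together; the helper names are hypothetical. ``count_unique`` uses a Python ``set``, which TorchScript does not support, so it is left as Python and called back from compiled code. ``where_am_i`` reports whether it is running as TorchScript. Code that calls an ignored function still needs the Python interpreter at run time.

```python
import torch


@torch.jit.ignore
def count_unique(x: torch.Tensor) -> int:
    # Left as ordinary Python: `set` is not supported by TorchScript.
    return len(set(x.tolist()))


@torch.jit.script
def unique_plus_one(x: torch.Tensor) -> int:
    # The compiled code calls back into Python for `count_unique`.
    return count_unique(x) + 1


def where_am_i() -> str:
    # True when running as compiled TorchScript, False in eager Python.
    if torch.jit.is_scripting():
        return "scripted"
    return "eager"


print(unique_plus_one(torch.tensor([1, 1, 2])))      # 3
print(where_am_i(), torch.jit.script(where_am_i)())  # eager scripted
```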
Attribute Lookup On Python Modules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

TorchScript can look up attributes on modules. `Builtin functions`_ like ``torch.add``
are accessed this way. This allows TorchScript to call functions defined in
other modules.
.. _constant:

Python-defined Constants
^^^^^^^^^^^^^^^^^^^^^^^^

TorchScript also provides a way to use constants that are defined in Python.
These can be used to hard-code hyper-parameters into the function, or to
define universal constants. There are two ways of specifying that a Python
value should be treated as a constant.

1. Values looked up as attributes of a module are assumed to be constant:

   .. testcode::

       import math
       import torch

       @torch.jit.script
       def fn():
           return math.pi

2. Attributes of a ScriptModule can be marked constant by annotating them with ``Final[T]``:

   ::

       import torch
       import torch.nn as nn

       class Foo(nn.Module):
           # `Final` from the `typing_extensions` module can also be used
           a : torch.jit.Final[int]

           def __init__(self):
               super(Foo, self).__init__()
               self.a = 1 + 4

           def forward(self, input):
               return self.a + input

       f = torch.jit.script(Foo())

Supported constant Python types are:

* ``int``
* ``float``
* ``bool``
* ``torch.device``
* ``torch.layout``
* ``torch.dtype``
* tuples containing supported types
* ``torch.nn.ModuleList`` which can be used in a TorchScript for loop

.. note::
    If you are on Python 2, you can mark an attribute as a constant by adding
    its name to the ``__constants__`` property of the class:

    .. testcode::

        import torch
        import torch.nn as nn

        class Foo(nn.Module):
            __constants__ = ['a']

            def __init__(self):
                super(Foo, self).__init__()
                self.a = 1 + 4

            def forward(self, input):
                return self.a + input

        f = torch.jit.script(Foo())
|
.. _module attributes:
Module Attributes