[dynamo, docs] update dynamo backend registration docs (pytorch#114820)
Update docs to reflect current backend registration API. Add `lookup_backend` to root `dynamo` module.

Pull Request resolved: pytorch#114820
Approved by: https://github.com/eellison
williamwen42 authored and pytorchmergebot committed Nov 30, 2023
1 parent 1f845d5 commit 38ae17d
Showing 2 changed files with 94 additions and 93 deletions.
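
For orientation, here is a minimal sketch of the registry API this commit documents; all three helpers are importable from the root ``torch._dynamo`` module after this change (the backend name ``my_backend`` is illustrative, not part of the API):

.. code-block:: python

    import torch._dynamo as dynamo

    # The registry helpers are all exposed on the root module.
    print(dynamo.list_backends())                 # names of registered backends
    inductor = dynamo.lookup_backend("inductor")  # resolve a name to its compiler fn

    @dynamo.register_backend
    def my_backend(gm, example_inputs):
        # Illustrative no-op backend: run the captured FX graph eagerly.
        return gm.forward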
184 changes: 92 additions & 92 deletions docs/source/torch.compiler_custom_backends.rst
@@ -45,7 +45,7 @@ You can register your backend using the ``register_backend`` decorator, for example

.. code-block:: python

  - from torch._dynamo.optimizations import register_backend
  + from torch._dynamo import register_backend

    @register_backend
    def my_compiler(gm, example_inputs):
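
Once registered this way, the backend should also be selectable by its string name in ``torch.compile``; a minimal usage sketch under that assumption:

.. code-block:: python

    import torch
    from torch._dynamo import register_backend

    @register_backend
    def my_compiler(gm, example_inputs):
        return gm.forward  # no-op backend: run the captured graph eagerly

    @torch.compile(backend="my_compiler")  # resolved by name through the registry
    def fn(x, y):
        return torch.sin(x) + torch.cos(y)

    fn(torch.randn(4), torch.randn(4))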
@@ -112,7 +112,7 @@ For example,

.. code-block:: python

  - from torch._dynamo.optimizations.training import aot_autograd
  + from torch._dynamo.backends.common import aot_autograd
    from functorch.compile import make_boxed_func

    def my_compiler(gm, example_inputs):
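
The example is cut off by the diff view here; the following is a hedged sketch of how the pieces typically fit together, where ``my_backend`` and the ``torch.nn.Linear`` model are illustrative, not part of the diff:

.. code-block:: python

    import torch
    from torch._dynamo.backends.common import aot_autograd
    from functorch.compile import make_boxed_func

    def my_compiler(gm, example_inputs):
        # AOTAutograd expects the compiler to return a "boxed" callable.
        return make_boxed_func(gm.forward)

    # Wrap the compiler so it receives the forward/backward graphs AOTAutograd traces.
    my_backend = aot_autograd(fw_compiler=my_compiler)

    model = torch.nn.Linear(4, 4)
    opt_model = torch.compile(model, backend=my_backend)
    opt_model(torch.randn(2, 4)).sum().backward()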
@@ -138,107 +138,107 @@ For example:

.. code-block:: python

    from typing import List

    import torch

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    @torch.compile(backend=my_compiler)
    def fn(x, y):
        a = torch.cos(x)
        b = torch.sin(y)
        return a + b

    fn(torch.randn(10), torch.randn(10))

Running the above example produces the following output:

::

    my_compiler() called with FX graph:
    opcode         name    target                                                   args        kwargs
    -------------  ------  -------------------------------------------------------  ----------  --------
    placeholder    x       x                                                        ()          {}
    placeholder    y       y                                                        ()          {}
    call_function  cos     <built-in method cos of type object at 0x7f1a894649a8>   (x,)        {}
    call_function  sin     <built-in method sin of type object at 0x7f1a894649a8>   (y,)        {}
    call_function  add     <built-in function add>                                  (cos, sin)  {}
    output         output  output                                                   ((add,),)   {}

This works for ``torch.nn.Module`` as well, as shown below:

.. code-block:: python

    from typing import List

    import torch

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    class MockModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(torch.cos(x))

    mod = MockModule()
    optimized_mod = torch.compile(mod, backend=my_compiler)
    optimized_mod(torch.randn(10))

Let’s take a look at one more example with control flow:

.. code-block:: python

    from typing import List

    import torch

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    @torch.compile(backend=my_compiler)
    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        if b.sum() < 0:
            b = b * -1
        return x * b

    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))

Running this example produces the following output:

::

    my_compiler() called with FX graph:
    opcode         name     target                                                   args              kwargs
    -------------  -------  -------------------------------------------------------  ----------------  --------
    placeholder    a        a                                                        ()                {}
    placeholder    b        b                                                        ()                {}
    call_function  abs_1    <built-in method abs of type object at 0x7f8d259298a0>   (a,)              {}
    call_function  add      <built-in function add>                                  (abs_1, 1)        {}
    call_function  truediv  <built-in function truediv>                              (a, add)          {}
    call_method    sum_1    sum                                                      (b,)              {}
    call_function  lt       <built-in function lt>                                   (sum_1, 0)        {}
    output         output   output                                                   ((truediv, lt),)  {}

    my_compiler() called with FX graph:
    opcode         name    target                   args         kwargs
    -------------  ------  -----------------------  -----------  --------
    placeholder    b       b                        ()           {}
    placeholder    x       x                        ()           {}
    call_function  mul     <built-in function mul>  (b, -1)      {}
    call_function  mul_1   <built-in function mul>  (x, mul)     {}
    output         output  output                   ((mul_1,),)  {}

    my_compiler() called with FX graph:
    opcode         name    target                   args       kwargs
    -------------  ------  -----------------------  ---------  --------
    placeholder    b       b                        ()         {}
    placeholder    x       x                        ()         {}
    call_function  mul     <built-in function mul>  (x, b)     {}
    output         output  output                   ((mul,),)  {}

The order of the last two graphs is nondeterministic depending
on which one is encountered first by the just-in-time compiler.
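
To inspect why ``toy_example`` splits into multiple graphs, ``torch._dynamo.explain`` can summarize the graph breaks; a brief sketch (the exact report format varies between releases):

.. code-block:: python

    import torch

    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        if b.sum() < 0:
            b = b * -1
        return x * b

    # explain() reports captured graphs, graph break reasons, and op counts.
    explanation = torch._dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
    print(explanation)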

Speedy Backend
^^^^^^^^^^^^^^
@@ -249,17 +249,17 @@ with `optimize_for_inference <https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html>`_

.. code-block:: python

    def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        scripted = torch.jit.script(gm)
        return torch.jit.optimize_for_inference(scripted)

And then you should be able to optimize any existing code with:

.. code-block:: python

    @torch.compile(backend=optimize_for_inference_compiler)
    def code_to_accelerate():
        ...

Composable Backends
^^^^^^^^^^^^^^^^^^^
@@ -271,17 +271,17 @@ together with the following code:

.. code-block:: python

  - from torch._dynamo.optimizations import BACKENDS
  + from torch._dynamo import lookup_backend

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        try:
  -         trt_compiled = BACKENDS["tensorrt"](gm, example_inputs)
  +         trt_compiled = lookup_backend("tensorrt")(gm, example_inputs)
            if trt_compiled is not None:
                return trt_compiled
        except Exception:
            pass
        # first backend failed, try something else...
        try:
  -         inductor_compiled = BACKENDS["inductor"](gm, example_inputs)
  +         inductor_compiled = lookup_backend("inductor")(gm, example_inputs)
            if inductor_compiled is not None:
                return inductor_compiled
        except Exception:
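
The diff view truncates the example mid-``except``; below is a self-contained sketch of the complete fallback pattern, where the final eager fallback (returning ``gm.forward``) is an assumption rather than part of this diff:

.. code-block:: python

    from typing import List

    import torch
    from torch._dynamo import lookup_backend

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        for name in ("tensorrt", "inductor"):
            try:
                compiled = lookup_backend(name)(gm, example_inputs)
                if compiled is not None:
                    return compiled
            except Exception:
                continue  # backend missing or failed to compile; try the next one
        return gm.forward  # assumed last resort: run the graph eagerly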
3 changes: 2 additions & 1 deletion torch/_dynamo/__init__.py
@@ -1,6 +1,6 @@
import torch
from . import allowed_functions, convert_frame, eval_frame, resume_execution
- from .backends.registry import list_backends, register_backend
+ from .backends.registry import list_backends, lookup_backend, register_backend
from .code_context import code_context
from .convert_frame import replay
from .decorators import (
@@ -51,6 +51,7 @@
"is_compiling",
"register_backend",
"list_backends",
"lookup_backend",
]

if torch.manual_seed is torch.random.manual_seed:
