diff --git a/docs/source/torch.compiler_custom_backends.rst b/docs/source/torch.compiler_custom_backends.rst index db980d7b3e833..47b749fcf83ba 100644 --- a/docs/source/torch.compiler_custom_backends.rst +++ b/docs/source/torch.compiler_custom_backends.rst @@ -45,7 +45,7 @@ You can register your backend using the ``register_backend`` decorator, for exam .. code-block:: python - from torch._dynamo.optimizations import register_backend + from torch._dynamo import register_backend @register_backend def my_compiler(gm, example_inputs): @@ -112,7 +112,7 @@ For example, .. code-block:: python - from torch._dynamo.optimizations.training import aot_autograd + from torch._dynamo.backends.common import aot_autograd from functorch.compile import make_boxed_func def my_compiler(gm, example_inputs): @@ -138,107 +138,107 @@ For example: .. code-block:: python - from typing import List - import torch - def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - print("my_compiler() called with FX graph:") - gm.graph.print_tabular() - return gm.forward # return a python callable - @torch.compile(backend=my_compiler) - def fn(x, y): - a = torch.cos(x) - b = torch.sin(y) - return a + b - fn(torch.randn(10), torch.randn(10)) + from typing import List + import torch + def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + return gm.forward # return a python callable + @torch.compile(backend=my_compiler) + def fn(x, y): + a = torch.cos(x) + b = torch.sin(y) + return a + b + fn(torch.randn(10), torch.randn(10)) Running the above example produces the following output: :: - my_compiler() called with FX graph: - opcode name target args kwargs - ------------- ------ ------------------------------------------------------ ---------- -------- - placeholder x x () {} - placeholder y y () {} - call_function cos (x,) {} - call_function sin (y,) {} - call_function add (cos, sin) {} - output output output ((add,),) {} + my_compiler() called with FX graph: + opcode name target args kwargs + ------------- ------ ------------------------------------------------------ ---------- -------- + placeholder x x () {} + placeholder y y () {} + call_function cos (x,) {} + call_function sin (y,) {} + call_function add (cos, sin) {} + output output output ((add,),) {} This works for ``torch.nn.Module`` as well as shown below: .. code-block:: python - from typing import List - import torch - def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - print("my_compiler() called with FX graph:") - gm.graph.print_tabular() - return gm.forward # return a python callable - class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - def forward(self, x): - return self.relu(torch.cos(x)) - mod = MockModule() - optimized_mod = torch.compile(mod, backend=my_compiler) - optimized_mod(torch.randn(10)) + from typing import List + import torch + def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + return gm.forward # return a python callable + class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + def forward(self, x): + return self.relu(torch.cos(x)) + mod = MockModule() + optimized_mod = torch.compile(mod, backend=my_compiler) + optimized_mod(torch.randn(10)) Let’s take a look at one more example with control flow: .. code-block:: python - from typing import List - import torch - def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - print("my_compiler() called with FX graph:") - gm.graph.print_tabular() - return gm.forward # return a python callable - @torch.compile(backend=my_compiler) - def toy_example(a, b): - x = a / (torch.abs(a) + 1) - if b.sum() < 0: - b = b * -1 - return x * b - for _ in range(100): - toy_example(torch.randn(10), torch.randn(10)) + from typing import List + import torch + def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + return gm.forward # return a python callable + @torch.compile(backend=my_compiler) + def toy_example(a, b): + x = a / (torch.abs(a) + 1) + if b.sum() < 0: + b = b * -1 + return x * b + for _ in range(100): + toy_example(torch.randn(10), torch.randn(10)) Running this example produces the following output: :: - my_compiler() called with FX graph: - opcode name target args kwargs - ------------- ------- ------------------------------------------------------ ---------------- -------- - placeholder a a () {} - placeholder b b () {} - call_function abs_1 (a,) {} - call_function add (abs_1, 1) {} - call_function truediv (a, add) {} - call_method sum_1 sum (b,) {} - call_function lt (sum_1, 0) {} - output output output ((truediv, lt),) {} - - my_compiler() called with FX graph: - opcode name target args kwargs - ------------- ------ ----------------------- ----------- -------- - placeholder b b () {} - placeholder x x () {} - call_function mul (b, -1) {} - call_function mul_1 (x, mul) {} - output output output ((mul_1,),) {} - - my_compiler() called with FX graph: - opcode name target args kwargs - ------------- ------ ----------------------- --------- -------- - placeholder b b () {} - placeholder x x () {} - call_function mul (x, b) {} - output output output ((mul,),) {} - -The order of the last two graphs is nondeterministic depending -on which one is encountered first by the just-in-time compiler. + my_compiler() called with FX graph: + opcode name target args kwargs + ------------- ------- ------------------------------------------------------ ---------------- -------- + placeholder a a () {} + placeholder b b () {} + call_function abs_1 (a,) {} + call_function add (abs_1, 1) {} + call_function truediv (a, add) {} + call_method sum_1 sum (b,) {} + call_function lt (sum_1, 0) {} + output output output ((truediv, lt),) {} + + my_compiler() called with FX graph: + opcode name target args kwargs + ------------- ------ ----------------------- ----------- -------- + placeholder b b () {} + placeholder x x () {} + call_function mul (b, -1) {} + call_function mul_1 (x, mul) {} + output output output ((mul_1,),) {} + + my_compiler() called with FX graph: + opcode name target args kwargs + ------------- ------ ----------------------- --------- -------- + placeholder b b () {} + placeholder x x () {} + call_function mul (x, b) {} + output output output ((mul,),) {} + + The order of the last two graphs is nondeterministic depending + on which one is encountered first by the just-in-time compiler. Speedy Backend ^^^^^^^^^^^^^^ @@ -249,17 +249,17 @@ with `optimize_for_inference