From 180d283156ca89671d11abc900c53878f4137982 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 30 Oct 2024 18:12:43 +0000 Subject: [PATCH] [export] avoid debug name crash for dim hints (#139104) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/139104 Approved by: https://github.com/ezyang --- test/export/test_export.py | 15 +++++++++++++++ torch/fx/experimental/symbolic_shapes.py | 7 ++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 35ae69a0a838eb..43b3a760afcd0c 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3128,6 +3128,21 @@ def forward(self, image, crop_height, crop_width): args = (torch.rand(3, 700, 700), 150, 150) self.assertEqual(ecrop.module()(*args), ecrop(*args)) + def test_dim_dynamic_divisibility(self): + class M(torch.nn.Module): + def forward(self, x): + if x.size(0) % 2 == 0: + return x.clone() * 2 + else: + return x.clone() * 0 + + input1 = (torch.randn(4),) + model = M() + dynamic_shapes = { + "x": {0: torch.export.Dim.DYNAMIC}, + } + export(model, input1, dynamic_shapes=dynamic_shapes) + def test_export_func_with_kwargs(self): class Module(torch.nn.Module): def forward(self, arg1, arg2, kw1, kw2): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 13b7c828e9b915..f4653e03564309 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -2436,7 +2436,12 @@ def solve(self) -> None: ): if self._is_supported_congruence(congruence): base, divisor = congruence.args - tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}" + tmp_name = "_" + str( + self._dcp.source_name_to_debug_name.get( + self._dcp.symbol_to_source[s][0].name(), + self._dcp.symbol_to_source[s][0].name(), + ) + ) tmp = sympy.Symbol(tmp_name, integer=True) from torch._dynamo.source import ConstantSource