diff --git a/test/export/test_export.py b/test/export/test_export.py index 35ae69a0a838e..43b3a760afcd0 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3128,6 +3128,21 @@ def forward(self, image, crop_height, crop_width): args = (torch.rand(3, 700, 700), 150, 150) self.assertEqual(ecrop.module()(*args), ecrop(*args)) + def test_dim_dynamic_divisibility(self): + class M(torch.nn.Module): + def forward(self, x): + if x.size(0) % 2 == 0: + return x.clone() * 2 + else: + return x.clone() * 0 + + input1 = (torch.randn(4),) + model = M() + dynamic_shapes = { + "x": {0: torch.export.Dim.DYNAMIC}, + } + export(model, input1, dynamic_shapes=dynamic_shapes) + def test_export_func_with_kwargs(self): class Module(torch.nn.Module): def forward(self, arg1, arg2, kw1, kw2): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 13b7c828e9b91..f4653e0356430 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -2436,7 +2436,12 @@ def solve(self) -> None: ): if self._is_supported_congruence(congruence): base, divisor = congruence.args - tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}" + tmp_name = "_" + str( + self._dcp.source_name_to_debug_name.get( + self._dcp.symbol_to_source[s][0].name(), + self._dcp.symbol_to_source[s][0].name(), + ) + ) tmp = sympy.Symbol(tmp_name, integer=True) from torch._dynamo.source import ConstantSource