From 75bb049d38e22ec9352eb0cfa5269f05eff75e8f Mon Sep 17 00:00:00 2001 From: Adnan Akhundov Date: Fri, 8 Mar 2024 10:50:59 -0800 Subject: [PATCH] Skip AOT Inductor test_cond_* tests on ROCm (#121522) Summary: The newly added tests in https://github.com/pytorch/pytorch/pull/121120 are failing in the `ciflow/periodic` jobs. Here we skip those on ROCm to avoid the need to disable those tests manually on ROCm. Test Plan: ``` $ python test/inductor/test_aot_inductor.py -k test_cond_nested ... ---------------------------------------------------------------------- Ran 6 tests in 72.122s OK ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/121522 Approved by: https://github.com/huydhn, https://github.com/malfet ghstack dependencies: #121120 --- test/inductor/test_aot_inductor.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index ffac33eb31bb2c..d347e4de9593c2 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -757,6 +757,7 @@ def forward(self, x, y): ) self.check_model(Repro(), example_inputs) + @skipIfRocm def test_cond_simple(self): inputs = ( torch.randn((10, 20), device=self.device), @@ -774,6 +775,7 @@ def test_cond_simple(self): dynamic_shapes=dynamic_shapes, ) + @skipIfRocm def test_cond_nested(self): inputs = ( torch.randn((10, 20), device=self.device), @@ -795,6 +797,7 @@ def test_cond_nested(self): dynamic_shapes=dynamic_shapes, ) + @skipIfRocm def test_cond_with_parameters(self): inputs = (torch.randn((10, 20), device=self.device),) dim0_abc = Dim("s0", min=2, max=1024) @@ -808,6 +811,7 @@ def test_cond_with_parameters(self): dynamic_shapes=dynamic_shapes, ) + @skipIfRocm def test_cond_with_reinterpret_view_inputs_outputs(self): inputs = ( torch.randn((10, 20), device=self.device), @@ -825,6 +829,7 @@ def test_cond_with_reinterpret_view_inputs_outputs(self): dynamic_shapes=dynamic_shapes, ) + @skipIfRocm def test_cond_with_multiple_outputs(self): inputs = ( torch.randn((10, 20), device=self.device), @@ -845,6 +850,7 @@ def test_cond_with_multiple_outputs(self): dynamic_shapes=dynamic_shapes, ) + @skipIfRocm def test_cond_with_outer_code_before_after(self): inputs = ( torch.randn((10, 20), device=self.device), @@ -862,6 +868,7 @@ def test_cond_with_outer_code_before_after(self): dynamic_shapes=dynamic_shapes, ) + @skipIfRocm def test_cond_use_buffers_from_outer_scope(self): inputs = ( torch.randn((10, 20), device=self.device),