diff --git a/test/test_nn.py b/test/test_nn.py index c8311c91d73ac..db086efa1cbd5 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8421,7 +8421,6 @@ def test_cudnn_rnn_dropout_states_device(self): output = rnn(input, hx) @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') - @skipIfRocm def test_cudnn_weight_format(self): rnns = [ nn.LSTM(10, 20, batch_first=True), @@ -8460,8 +8459,9 @@ def test_cudnn_weight_format(self): with warnings.catch_warnings(record=True) as w: output_noncontig = rnn(input, hx) if first_warn: - self.assertEqual(len(w), 1) - self.assertIn('weights are not part of single contiguous chunk of memory', w[0].message.args[0]) + if (torch.version.hip is None): + self.assertEqual(len(w), 1) + self.assertIn('weights are not part of single contiguous chunk of memory', w[0].message.args[0]) first_warn = False warnings.resetwarnings() output_noncontig[0].sum().backward()