From aa9778d24465936b3fa868c5ac35222207f1d580 Mon Sep 17 00:00:00 2001 From: Kiuk Chung Date: Mon, 16 Sep 2024 10:20:21 -0700 Subject: [PATCH] [torch/numpy][numpy2.0 compat] Use np.exceptions.ComplexWarning if np.exceptions module exists in numpy_tests --- test/torch_np/numpy_tests/core/test_einsum.py | 2 -- test/torch_np/numpy_tests/core/test_indexing.py | 13 +++++++++---- test/torch_np/numpy_tests/core/test_multiarray.py | 13 +++++++++---- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/test/torch_np/numpy_tests/core/test_einsum.py b/test/torch_np/numpy_tests/core/test_einsum.py index 029e10cead4ab..5432fb63d1800 100644 --- a/test/torch_np/numpy_tests/core/test_einsum.py +++ b/test/torch_np/numpy_tests/core/test_einsum.py @@ -440,8 +440,6 @@ def check_einsum_sums(self, dtype, do_opt=False): # Suppress the complex warnings for the 'as f8' tests with suppress_warnings() as sup: - # sup.filter(np.ComplexWarning) - # matvec(a,b) / a.dot(b) where a is matrix, b is vector for n in range(1, 17): a = np.arange(4 * n, dtype=dtype).reshape(4, n) diff --git a/test/torch_np/numpy_tests/core/test_indexing.py b/test/torch_np/numpy_tests/core/test_indexing.py index c5823dc30ff37..e88151ac689a0 100644 --- a/test/torch_np/numpy_tests/core/test_indexing.py +++ b/test/torch_np/numpy_tests/core/test_indexing.py @@ -602,16 +602,21 @@ def test_boolean_index_cast_assign(self): zero_array[bool_index] = np.array([1]) assert_equal(zero_array[0, 1], 1) + # np.ComplexWarning moved to np.exceptions in numpy>=2.0.0 + # np.exceptions only available in numpy>=1.25.0 + has_exceptions_mod = getattr(np, "exceptions", None) + ComplexWarning = ( + np.exceptions.ComplexWarning if has_exceptions_mod else np.ComplexWarning + ) + # Fancy indexing works, although we get a cast warning. assert_warns( - np.ComplexWarning, zero_array.__setitem__, ([0], [1]), np.array([2 + 1j]) + ComplexWarning, zero_array.__setitem__, ([0], [1]), np.array([2 + 1j]) ) assert_equal(zero_array[0, 1], 2) # No complex part # Cast complex to float, throwing away the imaginary portion. - assert_warns( - np.ComplexWarning, zero_array.__setitem__, bool_index, np.array([1j]) - ) + assert_warns(ComplexWarning, zero_array.__setitem__, bool_index, np.array([1j])) assert_equal(zero_array[0, 1], 0) diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index baed25a9b021c..ac2ff87b7189f 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -5469,8 +5469,6 @@ def test_out_arg(self): out = np.zeros((5, 2), dtype=np.complex128) c = self.matmul(a, b, out=out) assert_(c is out) - # with suppress_warnings() as sup: - # sup.filter(np.ComplexWarning, '') c = c.astype(tgt.dtype) assert_array_equal(c, tgt) @@ -5852,9 +5850,16 @@ def test_complex_warning(self): x = np.array([1, 2]) y = np.array([1 - 2j, 1 + 2j]) + # np.ComplexWarning moved to np.exceptions in numpy>=2.0.0 + # np.exceptions only available in numpy>=1.25.0 + has_exceptions_mod = getattr(np, "exceptions", None) + ComplexWarning = ( + np.exceptions.ComplexWarning if has_exceptions_mod else np.ComplexWarning + ) + with warnings.catch_warnings(): - warnings.simplefilter("error", np.ComplexWarning) - assert_raises(np.ComplexWarning, x.__setitem__, slice(None), y) + warnings.simplefilter("error", ComplexWarning) + assert_raises(ComplexWarning, x.__setitem__, slice(None), y) assert_equal(x, [1, 2])