Skip to content

Commit

Permalink
[torch/numpy][numpy2.0 compat] Use np.exceptions.ComplexWarning if np…
Browse files Browse the repository at this point in the history
….exceptions module exists in numpy_tests
  • Loading branch information
kiukchung committed Sep 16, 2024
1 parent 5c4c662 commit aa9778d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
2 changes: 0 additions & 2 deletions test/torch_np/numpy_tests/core/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,6 @@ def check_einsum_sums(self, dtype, do_opt=False):

# Suppress the complex warnings for the 'as f8' tests
with suppress_warnings() as sup:
# sup.filter(np.ComplexWarning)

# matvec(a,b) / a.dot(b) where a is matrix, b is vector
for n in range(1, 17):
a = np.arange(4 * n, dtype=dtype).reshape(4, n)
Expand Down
13 changes: 9 additions & 4 deletions test/torch_np/numpy_tests/core/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,16 +602,21 @@ def test_boolean_index_cast_assign(self):
zero_array[bool_index] = np.array([1])
assert_equal(zero_array[0, 1], 1)

# np.ComplexWarning moved to np.exceptions in numpy>=2.0.0
# np.exceptions only available in numpy>=1.25.0
has_exceptions_mod = getattr(np, "exceptions", None)
ComplexWarning = (
np.exceptions.ComplexWarning if has_exceptions_mod else np.ComplexWarning
)

# Fancy indexing works, although we get a cast warning.
assert_warns(
np.ComplexWarning, zero_array.__setitem__, ([0], [1]), np.array([2 + 1j])
ComplexWarning, zero_array.__setitem__, ([0], [1]), np.array([2 + 1j])
)
assert_equal(zero_array[0, 1], 2) # No complex part

# Cast complex to float, throwing away the imaginary portion.
assert_warns(
np.ComplexWarning, zero_array.__setitem__, bool_index, np.array([1j])
)
assert_warns(ComplexWarning, zero_array.__setitem__, bool_index, np.array([1j]))
assert_equal(zero_array[0, 1], 0)


Expand Down
13 changes: 9 additions & 4 deletions test/torch_np/numpy_tests/core/test_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5469,8 +5469,6 @@ def test_out_arg(self):
out = np.zeros((5, 2), dtype=np.complex128)
c = self.matmul(a, b, out=out)
assert_(c is out)
# with suppress_warnings() as sup:
# sup.filter(np.ComplexWarning, '')
c = c.astype(tgt.dtype)
assert_array_equal(c, tgt)

Expand Down Expand Up @@ -5852,9 +5850,16 @@ def test_complex_warning(self):
x = np.array([1, 2])
y = np.array([1 - 2j, 1 + 2j])

# np.ComplexWarning moved to np.exceptions in numpy>=2.0.0
# np.exceptions only available in numpy>=1.25.0
has_exceptions_mod = getattr(np, "exceptions", None)
ComplexWarning = (
np.exceptions.ComplexWarning if has_exceptions_mod else np.ComplexWarning
)

with warnings.catch_warnings():
warnings.simplefilter("error", np.ComplexWarning)
assert_raises(np.ComplexWarning, x.__setitem__, slice(None), y)
warnings.simplefilter("error", ComplexWarning)
assert_raises(ComplexWarning, x.__setitem__, slice(None), y)
assert_equal(x, [1, 2])


Expand Down

0 comments on commit aa9778d

Please sign in to comment.