Skip to content

Commit

Permalink
Adds a raw torch test
Browse files Browse the repository at this point in the history
  • Loading branch information
dgasmith committed Jun 26, 2024
1 parent 8658505 commit 5db6252
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 29 deletions.
4 changes: 3 additions & 1 deletion devtools/conda-envs/torch-environment.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
name: test
channels:
- pytorch
- conda-forge
dependencies:
# Base depends
- python >=3.9
- pytorch-cpu >=2.0,<3.0.0a
- pytorch::pytorch-cpu >=2.0,<3.0.0a
- pytorch::cpuonly

# Testing
- autoflake
Expand Down
22 changes: 3 additions & 19 deletions opt_einsum/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
Testing routines for opt_einsum.
"""

from importlib import import_module
from importlib.util import find_spec
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload

import pytest
Expand All @@ -15,20 +13,6 @@
_sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4]
_default_dim_dict = {c: s for c, s in zip(_valid_chars, _sizes)}

HAS_NUMPY = find_spec("numpy") is not None

using_numpy = pytest.mark.skipif(
not HAS_NUMPY,
reason="Numpy not detected.",
)


def import_numpy_or_skip() -> Any:
if not HAS_NUMPY:
pytest.skip("Numpy not detected.")
else:
return import_module("numpy")


def build_shapes(string: str, dimension_dict: Optional[Dict[str, int]] = None) -> Tuple[TensorShapeType, ...]:
"""
Expand Down Expand Up @@ -80,7 +64,7 @@ def build_views(string: str, dimension_dict: Optional[Dict[str, int]] = None) ->
```
"""
np = import_numpy_or_skip()
np = pytest.importorskip("numpy")
views = []
for shape in build_shapes(string, dimension_dict=dimension_dict):
views.append(np.random.rand(*shape))
Expand Down Expand Up @@ -163,7 +147,7 @@ def rand_equation(
```
"""

np = import_numpy_or_skip()
np = pytest.importorskip("numpy")
if seed is not None:
np.random.seed(seed)

Expand Down Expand Up @@ -230,6 +214,6 @@ def build_arrays_from_tuples(path: PathType) -> List[Any]:
Returns:
The resulting arrays."""
np = import_numpy_or_skip()
np = pytest.importorskip("numpy")

return [np.random.rand(*x) for x in path]
35 changes: 28 additions & 7 deletions opt_einsum/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from opt_einsum.contract import ArrayShaped, infer_backend, parse_backend
from opt_einsum.testing import build_views

# NumPy is required for the majority of this file
np = pytest.importorskip("numpy")

try:
import cupy

Expand Down Expand Up @@ -74,9 +71,17 @@
]


def test_torch_raw() -> None:
"""Tests torch in the abscence of any other dependancy."""
torch = pytest.importorskip("torch")
result = contract("ij,jk->ik", torch.rand(6, 5), torch.rand(5, 4))
assert result.shape == (6, 4)


@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.")
@pytest.mark.parametrize("string", tests)
def test_tensorflow(string: str) -> None:
np = pytest.importorskip("numpy")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
opt = np.empty_like(ein)
Expand All @@ -99,6 +104,7 @@ def test_tensorflow(string: str) -> None:
@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_tensorflow_with_constants(constants: Set[int]) -> None:
np = pytest.importorskip("numpy")
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
Expand Down Expand Up @@ -128,6 +134,7 @@ def test_tensorflow_with_constants(constants: Set[int]) -> None:
@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.")
@pytest.mark.parametrize("string", tests)
def test_tensorflow_with_sharing(string: str) -> None:
np = pytest.importorskip("numpy")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)

Expand All @@ -153,6 +160,7 @@ def test_tensorflow_with_sharing(string: str) -> None:
@pytest.mark.skipif(not found_theano, reason="Theano not installed.")
@pytest.mark.parametrize("string", tests)
def test_theano(string: str) -> None:
np = pytest.importorskip("numpy")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
Expand All @@ -171,6 +179,7 @@ def test_theano(string: str) -> None:
@pytest.mark.skipif(not found_theano, reason="theano not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_theano_with_constants(constants: Set[int]) -> None:
np = pytest.importorskip("numpy")
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
Expand All @@ -197,6 +206,7 @@ def test_theano_with_constants(constants: Set[int]) -> None:
@pytest.mark.skipif(not found_theano, reason="Theano not installed.")
@pytest.mark.parametrize("string", tests)
def test_theano_with_sharing(string: str) -> None:
np = pytest.importorskip("numpy")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)

Expand All @@ -219,7 +229,8 @@ def test_theano_with_sharing(string: str) -> None:

@pytest.mark.skipif(not found_cupy, reason="Cupy not installed.")
@pytest.mark.parametrize("string", tests)
def test_cupy(string: str) -> None: # pragma: no cover
def test_cupy(string: str) -> None:
np = pytest.importorskip("numpy") # pragma: no cover
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
Expand All @@ -238,7 +249,8 @@ def test_cupy(string: str) -> None: # pragma: no cover

@pytest.mark.skipif(not found_cupy, reason="Cupy not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_cupy_with_constants(constants: Set[int]) -> None: # pragma: no cover
def test_cupy_with_constants(constants: Set[int]) -> None:
np = pytest.importorskip("numpy") # pragma: no cover
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
Expand Down Expand Up @@ -266,7 +278,8 @@ def test_cupy_with_constants(constants: Set[int]) -> None: # pragma: no cover

@pytest.mark.skipif(not found_jax, reason="jax not installed.")
@pytest.mark.parametrize("string", tests)
def test_jax(string: str) -> None: # pragma: no cover
def test_jax(string: str) -> None:
np = pytest.importorskip("numpy") # pragma: no cover
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
Expand All @@ -280,7 +293,8 @@ def test_jax(string: str) -> None: # pragma: no cover

@pytest.mark.skipif(not found_jax, reason="jax not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_jax_with_constants(constants: Set[int]) -> None: # pragma: no cover
def test_jax_with_constants(constants: Set[int]) -> None:
np = pytest.importorskip("numpy") # pragma: no cover
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
Expand All @@ -300,6 +314,7 @@ def test_jax_with_constants(constants: Set[int]) -> None: # pragma: no cover

@pytest.mark.skipif(not found_jax, reason="jax not installed.")
def test_jax_jit_gradient() -> None:
np = pytest.importorskip("numpy")
eq = "ij,jk,kl->"
shapes = (2, 3), (3, 4), (4, 2)
views = [np.random.randn(*s) for s in shapes]
Expand All @@ -323,6 +338,7 @@ def test_jax_jit_gradient() -> None:

@pytest.mark.skipif(not found_autograd, reason="autograd not installed.")
def test_autograd_gradient() -> None:
np = pytest.importorskip("numpy")
eq = "ij,jk,kl->"
shapes = (2, 3), (3, 4), (4, 2)
views = [np.random.randn(*s) for s in shapes]
Expand All @@ -342,6 +358,7 @@ def test_autograd_gradient() -> None:

@pytest.mark.parametrize("string", tests)
def test_dask(string: str) -> None:
np = pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")

views = build_views(string)
Expand All @@ -366,6 +383,7 @@ def test_dask(string: str) -> None:

@pytest.mark.parametrize("string", tests)
def test_sparse(string: str) -> None:
np = pytest.importorskip("numpy")
sparse = pytest.importorskip("sparse")

views = build_views(string)
Expand Down Expand Up @@ -402,6 +420,7 @@ def test_sparse(string: str) -> None:
@pytest.mark.skipif(not found_torch, reason="Torch not installed.")
@pytest.mark.parametrize("string", tests)
def test_torch(string: str) -> None:
np = pytest.importorskip("numpy")

views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
Expand All @@ -422,6 +441,7 @@ def test_torch(string: str) -> None:
@pytest.mark.skipif(not found_torch, reason="Torch not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_torch_with_constants(constants: Set[int]) -> None:
np = pytest.importorskip("numpy")
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
Expand Down Expand Up @@ -457,6 +477,7 @@ def test_auto_backend_custom_array_no_tensordot() -> None:

@pytest.mark.parametrize("string", tests)
def test_object_arrays_backend(string: str) -> None:
np = pytest.importorskip("numpy")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
assert ein.dtype != object
Expand Down
4 changes: 2 additions & 2 deletions opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest

import opt_einsum as oe
from opt_einsum.testing import build_shapes, rand_equation, using_numpy
from opt_einsum.testing import build_shapes, rand_equation
from opt_einsum.typing import ArrayIndexType, OptimizeKind, PathType, TensorShapeType

explicit_path_tests = {
Expand Down Expand Up @@ -131,8 +131,8 @@ def test_bad_path_option() -> None:
oe.contract("a,b,c", [1], [2], [3], optimize="optimall", shapes=True) # type: ignore


@using_numpy
def test_explicit_path() -> None:
pytest.importorskip("numpy")
x = oe.contract("a,b,c", [1], [2], [3], optimize=[(1, 2), (0, 1)])
assert x.item() == 6

Expand Down

0 comments on commit 5db6252

Please sign in to comment.