diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8bee5b9d..59678338 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -102,7 +102,7 @@ jobs: timeout-minutes: 60 strategy: matrix: - os: [ubuntu-latest, macos-latest] # jaxlib is not available on Windows + os: [ubuntu-latest, windows-latest, macos-latest] fail-fast: false steps: - name: Checkout diff --git a/CHANGELOG.md b/CHANGELOG.md index dd9bc4bb..4b2a2899 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Enable tests on Windows by [@XuehaiPan](https://github.com/XuehaiPan) in [#140](https://github.com/metaopt/torchopt/pull/140). - Add `ruff` and `flake8` plugins integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#138](https://github.com/metaopt/torchopt/pull/138) and [#139](https://github.com/metaopt/torchopt/pull/139). ### Changed diff --git a/pyproject.toml b/pyproject.toml index 1dd131a6..eb3a19ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,9 +89,9 @@ test = [ 'pytest', 'pytest-cov', 'pytest-xdist', - 'jax[cpu] >= 0.3', - 'jaxopt', - 'optax', + 'jax[cpu] >= 0.3; platform_system != "Windows"', + 'jaxopt; platform_system != "Windows"', + 'optax; platform_system != "Windows"', ] [tool.setuptools.packages.find] diff --git a/tests/requirements.txt b/tests/requirements.txt index b0fa5e51..1d12cc79 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,9 +3,9 @@ torch >= 1.13 --requirement ../requirements.txt -jax[cpu] >= 0.3 -jaxopt -optax +jax[cpu] >= 0.3; platform_system != 'Windows' +jaxopt; platform_system != 'Windows' +optax; platform_system != 'Windows' pytest pytest-cov diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 9e3722d3..8672c588 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -20,11 +20,7 @@ from types import FunctionType import functorch -import jax -import jax.numpy as jnp -import jaxopt import numpy as np -import optax import pytest import torch import torch.nn as nn @@ -38,6 +34,18 @@ from torchopt.diff.implicit import ImplicitMetaGradientModule +try: + import jax + import jax.numpy as jnp + import jaxopt + import optax + + HAS_JAX = True +except ImportError: + jax = jnp = jaxopt = optax = None + HAS_JAX = False + + BATCH_SIZE = 8 NUM_UPDATES = 3 @@ -108,6 +116,7 @@ def get_rr_dataset_torch() -> data.DataLoader: return loader +@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed') @helpers.parametrize( dtype=[torch.float64, torch.float32], lr=[1e-3, 1e-4], @@ -234,6 +243,7 @@ def outer_level(p, xs, ys): helpers.assert_pytree_all_close(params, jax_params_as_tensor) +@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed') @helpers.parametrize( dtype=[torch.float64, torch.float32], lr=[1e-3, 1e-4], @@ -361,6 +371,7 @@ def outer_level(p, xs, ys): helpers.assert_pytree_all_close(params, jax_params_as_tensor) +@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed') @helpers.parametrize( dtype=[torch.float64, torch.float32], lr=[1e-3, 1e-4], @@ -472,6 +483,7 @@ def outer_level(p, xs, ys): helpers.assert_pytree_all_close(tuple(model.parameters()), jax_params_as_tensor) +@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed') @helpers.parametrize( dtype=[torch.float64, torch.float32], lr=[1e-3, 1e-4], @@ -574,6 +586,7 @@ def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq): helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor) +@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed') @helpers.parametrize( dtype=[torch.float64, torch.float32], lr=[1e-3, 1e-4],