Skip to content

Commit

Permalink
test(workflows): enable tests on Windows (metaopt#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan authored Feb 20, 2023
1 parent 6fa85d6 commit 14962cf
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ jobs:
timeout-minutes: 60
strategy:
matrix:
os: [ubuntu-latest, macos-latest] # jaxlib is not available on Windows
os: [ubuntu-latest, windows-latest, macos-latest]
fail-fast: false
steps:
- name: Checkout
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Enable tests on Windows by [@XuehaiPan](https://github.com/XuehaiPan) in [#140](https://github.com/metaopt/torchopt/pull/140).
- Add `ruff` and `flake8` plugins integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#138](https://github.com/metaopt/torchopt/pull/138) and [#139](https://github.com/metaopt/torchopt/pull/139).

### Changed
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ test = [
'pytest',
'pytest-cov',
'pytest-xdist',
'jax[cpu] >= 0.3',
'jaxopt',
'optax',
'jax[cpu] >= 0.3; platform_system != "Windows"',
'jaxopt; platform_system != "Windows"',
'optax; platform_system != "Windows"',
]

[tool.setuptools.packages.find]
Expand Down
6 changes: 3 additions & 3 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ torch >= 1.13

--requirement ../requirements.txt

jax[cpu] >= 0.3
jaxopt
optax
jax[cpu] >= 0.3; platform_system != 'Windows'
jaxopt; platform_system != 'Windows'
optax; platform_system != 'Windows'

pytest
pytest-cov
Expand Down
21 changes: 17 additions & 4 deletions tests/test_implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@
from types import FunctionType

import functorch
import jax
import jax.numpy as jnp
import jaxopt
import numpy as np
import optax
import pytest
import torch
import torch.nn as nn
Expand All @@ -38,6 +34,18 @@
from torchopt.diff.implicit import ImplicitMetaGradientModule


try:
import jax
import jax.numpy as jnp
import jaxopt
import optax

HAS_JAX = True
except ImportError:
jax = jnp = jaxopt = optax = None
HAS_JAX = False


BATCH_SIZE = 8
NUM_UPDATES = 3

Expand Down Expand Up @@ -108,6 +116,7 @@ def get_rr_dataset_torch() -> data.DataLoader:
return loader


@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed')
@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
Expand Down Expand Up @@ -234,6 +243,7 @@ def outer_level(p, xs, ys):
helpers.assert_pytree_all_close(params, jax_params_as_tensor)


@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed')
@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
Expand Down Expand Up @@ -361,6 +371,7 @@ def outer_level(p, xs, ys):
helpers.assert_pytree_all_close(params, jax_params_as_tensor)


@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed')
@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
Expand Down Expand Up @@ -472,6 +483,7 @@ def outer_level(p, xs, ys):
helpers.assert_pytree_all_close(tuple(model.parameters()), jax_params_as_tensor)


@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed')
@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
Expand Down Expand Up @@ -574,6 +586,7 @@ def outer_level(params_jax, l2reg_jax, xs, ys, xq, yq):
helpers.assert_all_close(l2reg_torch, l2reg_jax_as_tensor)


@pytest.mark.skipif(not HAS_JAX, reason='JAX is not installed')
@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-3, 1e-4],
Expand Down

0 comments on commit 14962cf

Please sign in to comment.