-
Notifications
You must be signed in to change notification settings - Fork 26
54 lines (52 loc) · 1.79 KB
/
jax_tests.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
name: Dependency test JAX

on:
  # Runs when a label is added to a pull request (the job below only proceeds
  # for the 'test_jax' label) or when the workflow is triggered manually.
  pull_request:
    types: [labeled]
  workflow_dispatch:

jobs:
  jax_tests:
    # Parentheses only make the grouping explicit: && already binds tighter
    # than || in Actions expressions, so the condition is unchanged.
    if: ${{ (github.event.label.name == 'test_jax' && github.event_name == 'pull_request') || github.event_name == 'workflow_dispatch' }}
    runs-on: ubuntu-latest
    strategy:
      # Let every (jax-version, group) combination finish even if one fails.
      fail-fast: false
      matrix:
        # Versions are quoted so they are always treated as strings.
        # Versions intentionally not tested:
        #   - 0.4.32 is not available on PyPI
        #   - 0.4.36 has a bug that causes tests to fail
        #   - earlier jax versions are not compatible with other
        #     dependencies as of 2024-10-04
        # NOTE(review): 0.4.15 is also absent from this list without a
        # recorded reason -- confirm whether it should be tested.
        jax-version: ['0.4.12', '0.4.13', '0.4.14', '0.4.16', '0.4.17',
                      '0.4.18', '0.4.19', '0.4.20', '0.4.21', '0.4.22',
                      '0.4.23', '0.4.24', '0.4.25', '0.4.26', '0.4.27',
                      '0.4.28', '0.4.29', '0.4.30', '0.4.31', '0.4.33',
                      '0.4.34', '0.4.35', '0.4.37']
        # Must list every group index from 1 to the `--splits` value passed
        # to pytest below; any missing index is a slice of tests never run.
        group: [1, 2, 3]
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python 3.10
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Upgrade pip
        run: |
          python -m pip install --upgrade pip
      - name: Install dependencies with given JAX version
        run: |
          # Replace the jax requirement in requirements.txt with an exact pin
          # to this matrix entry's version, then show the result in the log.
          # NOTE(review): '/jax/d' removes every line that contains "jax"
          # anywhere, not only the jax requirement itself.
          sed -i '/jax/d' ./requirements.txt
          sed -i '1i\jax[cpu] == ${{ matrix.jax-version }}' ./requirements.txt
          cat ./requirements.txt
          # Assumes dev-requirements.txt pulls in ../requirements.txt (e.g. via
          # "-r"); otherwise the pin above has no effect -- TODO confirm.
          pip install -r ./devtools/dev-requirements.txt
      - name: Verify dependencies
        run: |
          python --version
          pip --version
          pip list
      - name: Test with pytest
        run: |
          pwd
          lscpu
          # --splits must equal the number of entries in matrix.group so that
          # each split is executed by exactly one job.
          python -m pytest -m unit \
            --durations=0 \
            --mpl \
            --maxfail=1 \
            --splits 3 \
            --group ${{ matrix.group }} \
            --splitting-algorithm least_duration