diff --git a/.github/workflows/notebook_tests.yml b/.github/workflows/notebook_tests.yml
index 2ffc82a52..a437d1bbb 100644
--- a/.github/workflows/notebook_tests.yml
+++ b/.github/workflows/notebook_tests.yml
@@ -90,6 +90,7 @@ jobs:
         if: env.has_changes == 'true'
         run: |
           source .venv-${{ env.version }}/bin/activate
+          export PYTHONPATH=$(pwd)
           pytest -v --nbmake "./docs/notebooks" \
             --nbmake-timeout=2000 \
             --ignore=./docs/notebooks/zernike_eval.ipynb \
diff --git a/tests/benchmarks/benchmark_cpu_small.py b/tests/benchmarks/benchmark_cpu_small.py
index f76033983..cda8c7ff2 100644
--- a/tests/benchmarks/benchmark_cpu_small.py
+++ b/tests/benchmarks/benchmark_cpu_small.py
@@ -255,7 +255,7 @@ def test_objective_jac_atf(benchmark):
     def run(x, objective):
         objective.jac_scaled_error(x, objective.constants).block_until_ready()

-    benchmark.pedantic(run, args=(x, objective), rounds=15, iterations=1)
+    benchmark.pedantic(run, args=(x, objective), rounds=20, iterations=1)


 @pytest.mark.slow
@@ -288,7 +288,7 @@ def setup():
         }
         return args, kwargs

-    benchmark.pedantic(perturb, setup=setup, rounds=10, iterations=1)
+    benchmark.pedantic(perturb, setup=setup, rounds=8, iterations=1)


 @pytest.mark.slow
@@ -321,7 +321,7 @@ def setup():
         }
         return args, kwargs

-    benchmark.pedantic(perturb, setup=setup, rounds=10, iterations=1)
+    benchmark.pedantic(perturb, setup=setup, rounds=8, iterations=1)


 @pytest.mark.slow
@@ -341,7 +341,7 @@ def test_proximal_jac_atf(benchmark):
     def run(x, prox):
         prox.jac_scaled_error(x, prox.constants).block_until_ready()

-    benchmark.pedantic(run, args=(x, prox), rounds=20, iterations=1)
+    benchmark.pedantic(run, args=(x, prox), rounds=10, iterations=1)


 @pytest.mark.slow
@@ -389,7 +389,7 @@ def test_proximal_freeb_jac(benchmark):
     def run(x, obj, prox):
         obj.jac_scaled_error(x, prox.constants).block_until_ready()

-    benchmark.pedantic(run, args=(x, obj, prox), rounds=15, iterations=1)
+    benchmark.pedantic(run, args=(x, obj, prox), rounds=10, iterations=1)


 @pytest.mark.slow
@@ -398,6 +398,7 @@ def test_solve_fixed_iter_compiled(benchmark):
     """Benchmark running eq.solve for fixed iteration count after compilation."""

     def setup():
+        jax.clear_caches()
         eq = desc.examples.get("ESTELL")
         with pytest.warns(UserWarning, match="Reducing radial"):
             eq.change_resolution(6, 6, 6, 12, 12, 12)
@@ -436,19 +437,18 @@ def run(eq):
 def test_LinearConstraintProjection_build(benchmark):
     """Benchmark LinearConstraintProjection build."""
     eq = desc.examples.get("W7-X")
-    obj = ObjectiveFunction(ForceBalance(eq))
-    con = get_fixed_boundary_constraints(eq)
-    con = maybe_add_self_consistency(eq, con)
-    con = ObjectiveFunction(con)

-    def run(obj, con):
+    def run():
         jax.clear_caches()
+        obj = ObjectiveFunction(ForceBalance(eq))
+        con = get_fixed_boundary_constraints(eq)
+        con = maybe_add_self_consistency(eq, con)
+        con = ObjectiveFunction(con)
         lc = LinearConstraintProjection(obj, con)
         lc.build()

     benchmark.pedantic(
         run,
-        args=(obj, con),
         rounds=10,
         iterations=1,
     )