Skip to content

Commit

Permalink
Add benchmark for running fixed iteration solve (#1054)
Browse files Browse the repository at this point in the history
Should be merged before #1050 so benchmark will run.
  • Loading branch information
f0uriest authored Jun 14, 2024
2 parents 0a8ec62 + 9986152 commit 8c4051a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/benchmarks/benchmark_cpu_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,17 @@ def run(x):
obj.jac_scaled(x, prox.constants).block_until_ready()

benchmark.pedantic(run, args=(x,), rounds=15, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_solve_fixed_iter(benchmark):
"""Benchmark running eq.solve for fixed iteration count."""
jax.clear_caches()
eq = desc.examples.get("ESTELL")
eq.change_resolution(6, 6, 6, 12, 12, 12)

def run(eq):
eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0)

benchmark.pedantic(run, args=(eq,), rounds=10, iterations=1)
14 changes: 14 additions & 0 deletions tests/benchmarks/benchmark_gpu_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,17 @@ def run(x):
obj.jac_scaled(x, prox.constants).block_until_ready()

benchmark.pedantic(run, args=(x,), rounds=15, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_solve_fixed_iter(benchmark):
"""Benchmark running eq.solve for fixed iteration count."""
jax.clear_caches()
eq = desc.examples.get("ESTELL")
eq.change_resolution(6, 6, 6, 12, 12, 12)

def run(eq):
eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0)

benchmark.pedantic(run, args=(eq,), rounds=10, iterations=1)

0 comments on commit 8c4051a

Please sign in to comment.