From e0ad696a66f9909661e1bac69669ae78cdf73845 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 10 Oct 2023 13:35:33 -0700 Subject: [PATCH] Fixed a crash with symbolic zero cotangents; fixed a crash with pytrees+CG --- lineax/_solve.py | 2 +- lineax/_solver/cg.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lineax/_solve.py b/lineax/_solve.py index c339d5d..9b246f8 100644 --- a/lineax/_solve.py +++ b/lineax/_solve.py @@ -256,7 +256,7 @@ def _keep_undefined(v, ct): return None -@eqxi.filter_primitive_transpose +@eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore def _linear_solve_transpose(inputs, cts_out): cts_solution, _, _ = cts_out operator, state, vector, options, solver, _ = inputs diff --git a/lineax/_solver/cg.py b/lineax/_solver/cg.py index 573737a..b536f2f 100644 --- a/lineax/_solver/cg.py +++ b/lineax/_solver/cg.py @@ -124,15 +124,15 @@ def mv(vector: PyTree) -> PyTree: mv = operator.mv preconditioner, y0 = preconditioner_and_y0(operator, vector, options) leaves, _ = jtu.tree_flatten(vector) + size = sum(leaf.size for leaf in leaves) if self.max_steps is None: - size = sum(leaf.size for leaf in leaves) max_steps = 10 * size # Copied from SciPy! else: max_steps = self.max_steps r0 = (vector**ω - mv(y0) ** ω).ω p0 = preconditioner.mv(r0) gamma0 = tree_dot(r0, p0) - rcond = resolve_rcond(None, vector.size, vector.size, vector.dtype) + rcond = resolve_rcond(None, size, size, jnp.result_type(*leaves)) initial_value = ( ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω, y0,