diff --git a/lineax/_solve.py b/lineax/_solve.py index c339d5d..9b246f8 100644 --- a/lineax/_solve.py +++ b/lineax/_solve.py @@ -256,7 +256,7 @@ def _keep_undefined(v, ct): return None -@eqxi.filter_primitive_transpose +@eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore def _linear_solve_transpose(inputs, cts_out): cts_solution, _, _ = cts_out operator, state, vector, options, solver, _ = inputs diff --git a/lineax/_solver/cg.py b/lineax/_solver/cg.py index 573737a..b536f2f 100644 --- a/lineax/_solver/cg.py +++ b/lineax/_solver/cg.py @@ -124,15 +124,15 @@ def mv(vector: PyTree) -> PyTree: mv = operator.mv preconditioner, y0 = preconditioner_and_y0(operator, vector, options) leaves, _ = jtu.tree_flatten(vector) + size = sum(leaf.size for leaf in leaves) if self.max_steps is None: - size = sum(leaf.size for leaf in leaves) max_steps = 10 * size # Copied from SciPy! else: max_steps = self.max_steps r0 = (vector**ω - mv(y0) ** ω).ω p0 = preconditioner.mv(r0) gamma0 = tree_dot(r0, p0) - rcond = resolve_rcond(None, vector.size, vector.size, vector.dtype) + rcond = resolve_rcond(None, size, size, jnp.result_type(*leaves)) initial_value = ( ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω, y0,