Skip to content

Commit

Permalink
Fixed a crash with symbolic zero cotangents; fixed a crash with pytre…
Browse files Browse the repository at this point in the history
…es+CG
  • Loading branch information
patrick-kidger committed Oct 10, 2023
1 parent 8b63370 commit e0ad696
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion lineax/_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def _keep_undefined(v, ct):
return None


@eqxi.filter_primitive_transpose
@eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore
def _linear_solve_transpose(inputs, cts_out):
cts_solution, _, _ = cts_out
operator, state, vector, options, solver, _ = inputs
Expand Down
4 changes: 2 additions & 2 deletions lineax/_solver/cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,15 @@ def mv(vector: PyTree) -> PyTree:
mv = operator.mv
preconditioner, y0 = preconditioner_and_y0(operator, vector, options)
leaves, _ = jtu.tree_flatten(vector)
size = sum(leaf.size for leaf in leaves)
if self.max_steps is None:
size = sum(leaf.size for leaf in leaves)
max_steps = 10 * size # Copied from SciPy!
else:
max_steps = self.max_steps
r0 = (vector**ω - mv(y0) ** ω).ω
p0 = preconditioner.mv(r0)
gamma0 = tree_dot(r0, p0)
rcond = resolve_rcond(None, vector.size, vector.size, vector.dtype)
rcond = resolve_rcond(None, size, size, jnp.result_type(*leaves))
initial_value = (
ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω,
y0,
Expand Down

0 comments on commit e0ad696

Please sign in to comment.