From 4026cef6238a0bbda6e1c92e51aad935c614954e Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 18 Oct 2024 22:25:41 +0200 Subject: [PATCH] Diagonal solver now works with differently input+output structures Honestly, it's a little suspicious whether we should even allow this: should being diagonal perhaps imply that the input and output structures are identical? Right now I'm choosing to allow this because it's pretty subtle, and users can apply their own diagonal tags, so if nothign else it's an easy mistake to be tolerant of. In particular, *we* were making this mistake by treating scalar operators as diagonal even when they had different structures. This was causing a downstream issue in Diffrax. --- lineax/_solver/diagonal.py | 44 ++++++++++++++++++++++++++------------ pyproject.toml | 2 +- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/lineax/_solver/diagonal.py b/lineax/_solver/diagonal.py index 24f1ac8..93f41bf 100644 --- a/lineax/_solver/diagonal.py +++ b/lineax/_solver/diagonal.py @@ -15,7 +15,6 @@ from typing import Any, Optional from typing_extensions import TypeAlias -import jax.flatten_util as jfu import jax.numpy as jnp from jaxtyping import Array, PyTree @@ -23,9 +22,16 @@ from .._operator import AbstractLinearOperator, diagonal, has_unit_diagonal, is_diagonal from .._solution import RESULTS from .._solve import AbstractLinearSolver +from .misc import ( + pack_structures, + PackedStructures, + ravel_vector, + transpose_packed_structures, + unravel_solution, +) -_DiagonalState: TypeAlias = Optional[Array] +_DiagonalState: TypeAlias = tuple[Optional[Array], PackedStructures] class Diagonal(AbstractLinearSolver[_DiagonalState], strict=True): @@ -52,39 +58,49 @@ def init( raise ValueError( "`Diagonal` may only be used for linear solves with diagonal matrices" ) + packed_structures = pack_structures(operator) if has_unit_diagonal(operator): - return None + return None, packed_structures else: - return diagonal(operator) + return diagonal(operator), packed_structures def compute( self, state: _DiagonalState, vector: PyTree[Array], options: dict[str, Any] ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: - diag = state + diag, packed_structures = state del state, options unit_diagonal = diag is None - # diagonal => symmetric => (in_structure == out_structure) => - # we don't need to use packed structures. + vector = ravel_vector(vector, packed_structures) if unit_diagonal: solution = vector else: - vector, unflatten = jfu.ravel_pytree(vector) if not self.well_posed: (size,) = diag.shape rcond = resolve_rcond(self.rcond, size, size, diag.dtype) abs_diag = jnp.abs(diag) diag = jnp.where(abs_diag > rcond * jnp.max(abs_diag), diag, jnp.inf) - solution = unflatten(vector / diag) + solution = vector / diag + solution = unravel_solution(solution, packed_structures) return solution, RESULTS.successful, {} def transpose(self, state: _DiagonalState, options: dict[str, Any]): - # Matrix is symmetric - return state, options + del options + diag, packed_structures = state + transposed_packed_structures = transpose_packed_structures(packed_structures) + transpose_state = diag, transposed_packed_structures + transpose_options = {} + return transpose_state, transpose_options def conj(self, state: _DiagonalState, options: dict[str, Any]): - if state is None: - return None, options - return state.conj(), options + del options + diag, packed_structures = state + if diag is None: + conj_diag = None + else: + conj_diag = diag.conj() + conj_options = {} + conj_state = conj_diag, packed_structures + return conj_state, conj_options def allow_dependent_columns(self, operator): return not self.well_posed diff --git a/pyproject.toml b/pyproject.toml index b495cd7..c8c77c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "lineax" -version = "0.0.6" +version = "0.0.7" description = "Linear solvers in JAX and Equinox." readme = "README.md" requires-python ="~=3.9"