diff --git a/lineax/_solver/diagonal.py b/lineax/_solver/diagonal.py index 24f1ac8..93f41bf 100644 --- a/lineax/_solver/diagonal.py +++ b/lineax/_solver/diagonal.py @@ -15,7 +15,6 @@ from typing import Any, Optional from typing_extensions import TypeAlias -import jax.flatten_util as jfu import jax.numpy as jnp from jaxtyping import Array, PyTree @@ -23,9 +22,16 @@ from .._operator import AbstractLinearOperator, diagonal, has_unit_diagonal, is_diagonal from .._solution import RESULTS from .._solve import AbstractLinearSolver +from .misc import ( + pack_structures, + PackedStructures, + ravel_vector, + transpose_packed_structures, + unravel_solution, +) -_DiagonalState: TypeAlias = Optional[Array] +_DiagonalState: TypeAlias = tuple[Optional[Array], PackedStructures] class Diagonal(AbstractLinearSolver[_DiagonalState], strict=True): @@ -52,39 +58,49 @@ def init( raise ValueError( "`Diagonal` may only be used for linear solves with diagonal matrices" ) + packed_structures = pack_structures(operator) if has_unit_diagonal(operator): - return None + return None, packed_structures else: - return diagonal(operator) + return diagonal(operator), packed_structures def compute( self, state: _DiagonalState, vector: PyTree[Array], options: dict[str, Any] ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: - diag = state + diag, packed_structures = state del state, options unit_diagonal = diag is None - # diagonal => symmetric => (in_structure == out_structure) => - # we don't need to use packed structures. + vector = ravel_vector(vector, packed_structures) if unit_diagonal: solution = vector else: - vector, unflatten = jfu.ravel_pytree(vector) if not self.well_posed: (size,) = diag.shape rcond = resolve_rcond(self.rcond, size, size, diag.dtype) abs_diag = jnp.abs(diag) diag = jnp.where(abs_diag > rcond * jnp.max(abs_diag), diag, jnp.inf) - solution = unflatten(vector / diag) + solution = vector / diag + solution = unravel_solution(solution, packed_structures) return solution, RESULTS.successful, {} def transpose(self, state: _DiagonalState, options: dict[str, Any]): - # Matrix is symmetric - return state, options + del options + diag, packed_structures = state + transposed_packed_structures = transpose_packed_structures(packed_structures) + transpose_state = diag, transposed_packed_structures + transpose_options = {} + return transpose_state, transpose_options def conj(self, state: _DiagonalState, options: dict[str, Any]): - if state is None: - return None, options - return state.conj(), options + del options + diag, packed_structures = state + if diag is None: + conj_diag = None + else: + conj_diag = diag.conj() + conj_options = {} + conj_state = conj_diag, packed_structures + return conj_state, conj_options def allow_dependent_columns(self, operator): return not self.well_posed diff --git a/pyproject.toml b/pyproject.toml index b495cd7..c8c77c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "lineax" -version = "0.0.6" +version = "0.0.7" description = "Linear solvers in JAX and Equinox." readme = "README.md" requires-python ="~=3.9"