From 6cf00ce789dfbcdec7898285a06b96e60a04ce71 Mon Sep 17 00:00:00 2001 From: Alvin Sun Date: Sat, 4 May 2024 20:09:48 -0700 Subject: [PATCH] support `has_aux` for manifold gradient functions (#17) * support has_aux for manifold gradient functions * Formatting, mypy --------- Co-authored-by: Brent Yi --- jaxlie/_base.py | 6 ++---- jaxlie/manifold/_backprop.py | 25 ++++++++++++++----------- jaxlie/manifold/_deltas.py | 12 ++++-------- tests/test_examples.py | 1 + tests/test_manifold.py | 1 + tests/utils.py | 5 ++--- 6 files changed, 24 insertions(+), 26 deletions(-) diff --git a/jaxlie/_base.py b/jaxlie/_base.py index 4453a92..973ca65 100644 --- a/jaxlie/_base.py +++ b/jaxlie/_base.py @@ -44,12 +44,10 @@ def __init__( # Shared implementations. @overload - def __matmul__(self: GroupType, other: GroupType) -> GroupType: - ... + def __matmul__(self: GroupType, other: GroupType) -> GroupType: ... @overload - def __matmul__(self, other: hints.Array) -> jax.Array: - ... + def __matmul__(self, other: hints.Array) -> jax.Array: ... def __matmul__( self: GroupType, other: Union[GroupType, hints.Array] diff --git a/jaxlie/manifold/_backprop.py b/jaxlie/manifold/_backprop.py index 7f1c318..0564a89 100644 --- a/jaxlie/manifold/_backprop.py +++ b/jaxlie/manifold/_backprop.py @@ -37,8 +37,7 @@ def grad( holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = (), -) -> Callable[P, _tree_utils.TangentPytree]: - ... +) -> Callable[P, _tree_utils.TangentPytree]: ... @overload @@ -49,8 +48,7 @@ def grad( holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = (), -) -> Callable[P, Tuple[_tree_utils.TangentPytree, ...]]: - ... +) -> Callable[P, Tuple[_tree_utils.TangentPytree, ...]]: ... def grad( @@ -72,7 +70,15 @@ def grad( allow_int=allow_int, reduce_axes=reduce_axes, ) - return lambda *args, **kwargs: compute_value_and_grad(*args, **kwargs)[1] # type: ignore + + def grad_fun(*args, **kwargs): + ret = compute_value_and_grad(*args, **kwargs) + if has_aux: + return ret[1], ret[0][1] + else: + return ret[1] + + return grad_fun @overload @@ -83,8 +89,7 @@ def value_and_grad( holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = (), -) -> Callable[P, Tuple[Any, _tree_utils.TangentPytree]]: - ... +) -> Callable[P, Tuple[Any, _tree_utils.TangentPytree]]: ... @overload @@ -95,8 +100,7 @@ def value_and_grad( holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = (), -) -> Callable[P, Tuple[Any, Tuple[_tree_utils.TangentPytree, ...]]]: - ... +) -> Callable[P, Tuple[Any, Tuple[_tree_utils.TangentPytree, ...]]]: ... def value_and_grad( @@ -121,7 +125,7 @@ def tangent_fun(*tangent_args, **tangent_kwargs): tangent_args = map(zero_tangents, args) tangent_kwargs = {k: zero_tangents(v) for k, v in kwargs.items()} - value, grad = jax.value_and_grad( + return jax.value_and_grad( fun=tangent_fun, argnums=argnums, has_aux=has_aux, @@ -129,6 +133,5 @@ def tangent_fun(*tangent_args, **tangent_kwargs): allow_int=allow_int, reduce_axes=reduce_axes, )(*tangent_args, **tangent_kwargs) - return value, grad return wrapped_grad # type: ignore diff --git a/jaxlie/manifold/_deltas.py b/jaxlie/manifold/_deltas.py index b2217d9..7b8b6c9 100644 --- a/jaxlie/manifold/_deltas.py +++ b/jaxlie/manifold/_deltas.py @@ -49,16 +49,14 @@ def _rplus(transform: GroupType, delta: jax.Array) -> GroupType: def rplus( transform: GroupType, delta: hints.Array, -) -> GroupType: - ... +) -> GroupType: ... @overload def rplus( transform: PytreeType, delta: _tree_utils.TangentPytree, -) -> PytreeType: - ... +) -> PytreeType: ... # Using our typevars in the overloaded signature will cause errors. @@ -81,13 +79,11 @@ def _rminus(a: GroupType, b: GroupType) -> jax.Array: @overload -def rminus(a: GroupType, b: GroupType) -> jax.Array: - ... +def rminus(a: GroupType, b: GroupType) -> jax.Array: ... @overload -def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree: - ... +def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree: ... # Using our typevars in the overloaded signature will cause errors. diff --git a/tests/test_examples.py b/tests/test_examples.py index d3eb234..67b7af0 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,4 +1,5 @@ """Tests with explicit examples.""" + import numpy as onp import pytest from hypothesis import given, settings diff --git a/tests/test_manifold.py b/tests/test_manifold.py index 9ef57cd..ca31f30 100644 --- a/tests/test_manifold.py +++ b/tests/test_manifold.py @@ -1,4 +1,5 @@ """Test manifold helpers.""" + from typing import Type import jax diff --git a/tests/utils.py b/tests/utils.py index 8da2c27..55d2780 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,12 +9,11 @@ from hypothesis import given, settings from hypothesis import strategies as st from jax import numpy as jnp -from jax.config import config import jaxlie # Run all tests with double-precision. -config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", True) T = TypeVar("T", bound=jaxlie.MatrixLieGroup) @@ -101,7 +100,7 @@ def assert_arrays_close( def jacnumerical( - f: Callable[[jaxlie.hints.Array], jax.Array] + f: Callable[[jaxlie.hints.Array], jax.Array], ) -> Callable[[jaxlie.hints.Array], jax.Array]: """Decorator for computing numerical Jacobians of vector->vector functions."""