Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support has_aux for manifold gradient functions #17

Merged
merged 2 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions jaxlie/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,10 @@ def __init__(
# Shared implementations.

@overload
def __matmul__(self: GroupType, other: GroupType) -> GroupType:
...
def __matmul__(self: GroupType, other: GroupType) -> GroupType: ...

@overload
def __matmul__(self, other: hints.Array) -> jax.Array:
...
def __matmul__(self, other: hints.Array) -> jax.Array: ...

def __matmul__(
self: GroupType, other: Union[GroupType, hints.Array]
Expand Down
25 changes: 14 additions & 11 deletions jaxlie/manifold/_backprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def grad(
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = (),
) -> Callable[P, _tree_utils.TangentPytree]:
...
) -> Callable[P, _tree_utils.TangentPytree]: ...


@overload
Expand All @@ -49,8 +48,7 @@ def grad(
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = (),
) -> Callable[P, Tuple[_tree_utils.TangentPytree, ...]]:
...
) -> Callable[P, Tuple[_tree_utils.TangentPytree, ...]]: ...


def grad(
Expand All @@ -72,7 +70,15 @@ def grad(
allow_int=allow_int,
reduce_axes=reduce_axes,
)
return lambda *args, **kwargs: compute_value_and_grad(*args, **kwargs)[1] # type: ignore

def grad_fun(*args, **kwargs):
ret = compute_value_and_grad(*args, **kwargs)
if has_aux:
return ret[1], ret[0][1]
else:
return ret[1]

return grad_fun


@overload
Expand All @@ -83,8 +89,7 @@ def value_and_grad(
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = (),
) -> Callable[P, Tuple[Any, _tree_utils.TangentPytree]]:
...
) -> Callable[P, Tuple[Any, _tree_utils.TangentPytree]]: ...


@overload
Expand All @@ -95,8 +100,7 @@ def value_and_grad(
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = (),
) -> Callable[P, Tuple[Any, Tuple[_tree_utils.TangentPytree, ...]]]:
...
) -> Callable[P, Tuple[Any, Tuple[_tree_utils.TangentPytree, ...]]]: ...


def value_and_grad(
Expand All @@ -121,14 +125,13 @@ def tangent_fun(*tangent_args, **tangent_kwargs):
tangent_args = map(zero_tangents, args)
tangent_kwargs = {k: zero_tangents(v) for k, v in kwargs.items()}

value, grad = jax.value_and_grad(
return jax.value_and_grad(
fun=tangent_fun,
argnums=argnums,
has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
)(*tangent_args, **tangent_kwargs)
return value, grad

return wrapped_grad # type: ignore
12 changes: 4 additions & 8 deletions jaxlie/manifold/_deltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,14 @@ def _rplus(transform: GroupType, delta: jax.Array) -> GroupType:
def rplus(
transform: GroupType,
delta: hints.Array,
) -> GroupType:
...
) -> GroupType: ...


@overload
def rplus(
transform: PytreeType,
delta: _tree_utils.TangentPytree,
) -> PytreeType:
...
) -> PytreeType: ...


# Using our typevars in the overloaded signature will cause errors.
Expand All @@ -81,13 +79,11 @@ def _rminus(a: GroupType, b: GroupType) -> jax.Array:


@overload
def rminus(a: GroupType, b: GroupType) -> jax.Array:
...
def rminus(a: GroupType, b: GroupType) -> jax.Array: ...


@overload
def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree:
...
def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree: ...


# Using our typevars in the overloaded signature will cause errors.
Expand Down
1 change: 1 addition & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests with explicit examples."""

import numpy as onp
import pytest
from hypothesis import given, settings
Expand Down
1 change: 1 addition & 0 deletions tests/test_manifold.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test manifold helpers."""

from typing import Type

import jax
Expand Down
5 changes: 2 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
from hypothesis import given, settings
from hypothesis import strategies as st
from jax import numpy as jnp
from jax.config import config

import jaxlie

# Run all tests with double-precision.
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

T = TypeVar("T", bound=jaxlie.MatrixLieGroup)

Expand Down Expand Up @@ -101,7 +100,7 @@ def assert_arrays_close(


def jacnumerical(
f: Callable[[jaxlie.hints.Array], jax.Array]
f: Callable[[jaxlie.hints.Array], jax.Array],
) -> Callable[[jaxlie.hints.Array], jax.Array]:
"""Decorator for computing numerical Jacobians of vector->vector functions."""

Expand Down
Loading