
Complex-aware AD helpers broken for R -> C functions when has_aux = True #1934

Open
Matematija opened this issue Sep 26, 2024 · 2 comments
Labels
contributor welcome We welcome contributions or PRs on this issue

Comments

@Matematija

Hi - I want to report a simple bug with netket.jax.value_and_grad. The following errors out:

from netket.jax import value_and_grad

def f(x):
    return x**2 + 1j * x ** 3, 42
    
value_and_grad(f, has_aux=True)(1.0)

I believe that this is most likely caused by JAX changing the output convention of jax.value_and_grad from

value, grad, aux = jax.value_and_grad(f, has_aux=True)(x)

to

(value, aux), grad = jax.value_and_grad(f, has_aux=True)(x)
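For reference, here is a minimal, self-contained check of the nested convention that `jax.value_and_grad(..., has_aux=True)` follows in the JAX version listed below (0.4.33). The toy function `f` is only an illustration, not NetKet code.

```python
import jax


def f(x):
    # First element is differentiated; the second is auxiliary data
    # that is passed through without being differentiated.
    return x**2, x**3


# With has_aux=True the returned function yields ((value, aux), grad).
(value, aux), grad = jax.value_and_grad(f, has_aux=True)(3.0)
# value = 9.0, aux = 27.0, grad = 6.0
```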

I don't know when this change happened or if you guys want/care to support the old convention at all. The error doesn't appear for R->R or C->C functions because in those cases the output of jax.value_and_grad is returned directly from nk.jax.value_and_grad. However, in the case of R->C functions, the output is unpacked manually here:

out_r, grad_r, aux = jax.value_and_grad(
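A minimal sketch of how an R->C branch could unpack JAX's nested `((value, aux), grad)` output. It assumes the convention grad = grad(Re f) + 1j * grad(Im f) for a real input; `value_and_grad_r2c` is a hypothetical helper for illustration only (not NetKet's implementation) and handles a single positional argument.

```python
import jax
import jax.numpy as jnp


def value_and_grad_r2c(fun, has_aux=False):
    """Hypothetical helper (not NetKet API): value and gradient of a
    real-input, complex-output `fun` of a single argument.

    Assumed convention: grad = grad(Re fun) + 1j * grad(Im fun).
    """

    def part(x, take):
        # Apply `take` (real or imag) to the primary output only; aux is untouched.
        if has_aux:
            out, aux = fun(x)
            return take(out), aux
        return take(fun(x))

    def re(x):
        return part(x, jnp.real)

    def im(x):
        return part(x, jnp.imag)

    def wrapped(x):
        if has_aux:
            # JAX returns ((value, aux), grad) when has_aux=True.
            (out_r, aux), grad_r = jax.value_and_grad(re, has_aux=True)(x)
            (out_i, _), grad_i = jax.value_and_grad(im, has_aux=True)(x)
            return (out_r + 1j * out_i, aux), grad_r + 1j * grad_i
        out_r, grad_r = jax.value_and_grad(re)(x)
        out_i, grad_i = jax.value_and_grad(im)(x)
        return out_r + 1j * out_i, grad_r + 1j * grad_i

    return wrapped


def f(x):
    # The R -> C example from the report, with a constant aux value.
    return x**2 + 1j * x**3, 42


(value, aux), grad = value_and_grad_r2c(f, has_aux=True)(1.0)
```

Differentiating the real and imaginary parts separately keeps every call to `jax.value_and_grad` on a real-valued output, which it requires unless `holomorphic=True`.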

This was tested for:

NetKet version: 3.14.2
JAX version: 0.4.33
OS: Debian
Platform: CPU
@PhilipVinc
Member

Thanks for the heads up.

However, I would not rely on nk.jax.value_and_grad.
IIRC it is neither tested nor used anywhere; it is more of an artefact from the past.
While nk.jax.vjp is a well-defined version of vjp with conventions different from those followed by JAX, I never checked carefully what this function does.

The only well-defined gradient function around here would be nk.jax.jacobian.

I'll keep this open as a reminder that we should remove nkjax.grad and nkjax.value_and_grad.
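As a plain-JAX workaround (a sketch, not NetKet's nk.jax.jacobian, whose interface is different), forward-mode `jax.jacfwd` accepts a real input with a complex output and supports `has_aux=True`, returning `(jacobian, aux)`. By contrast, `jax.grad` and `jax.jacrev` reject complex outputs for non-holomorphic differentiation. The function `f` is the toy example from the report.

```python
import jax


def f(x):
    # R -> C example from the report, with a constant aux value.
    return x**2 + 1j * x**3, 42


# Forward mode differentiates a real input with a complex output directly;
# with has_aux=True it returns (jacobian, aux).
jac, aux = jax.jacfwd(f, has_aux=True)(1.0)
```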

@PhilipVinc PhilipVinc added the contributor welcome We welcome contributions or PRs on this issue label Sep 26, 2024
@Matematija
Author

Yeah, that was exactly my scenario: I had nkjax.value_and_grad in some old code that I ran with a newer version of JAX, and it took me a hot minute to realize where the error was coming from. I agree with removing these functions if they are unused and untested.
