Complex-aware AD helpers broken for R -> C functions when has_aux = True
#1934
Labels: contributor welcome (We welcome contributions or PRs on this issue)
Hi, I want to report a simple bug with netket.jax.value_and_grad. The following errors out:

I believe this is most likely caused by JAX changing the output convention of jax.value_and_grad from … to …
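For reference, the JAX documentation specifies that jax.value_and_grad with has_aux=True returns the nested tuple ((value, aux), grad). Below is a minimal sketch that checks this structure on a simple R->R function. loss_with_aux is a hypothetical example function and is not part of NetKet or JAX:

```python
import jax
import jax.numpy as jnp


def loss_with_aux(x):
    # Hypothetical R^n -> R function that also returns auxiliary data.
    value = jnp.sum(x ** 2)
    aux = {"norm": jnp.sqrt(value)}
    return value, aux


x = jnp.array([1.0, 2.0, 3.0])

# With has_aux=True, JAX returns ((value, aux), grad): the primal output
# and the aux data are grouped in an inner tuple, and the gradient comes second.
(value, aux), grad = jax.value_and_grad(loss_with_aux, has_aux=True)(x)

print(value, aux["norm"], grad)
```

Code that instead unpacks the result as a flat (value, aux, grad) triple, or as (value, grad) with the aux ignored, will fail or misassign values under this convention.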
I don't know when this change happened, or whether you want (or care) to support the old convention at all. The error doesn't appear for R->R or C->C functions because in those cases the output of jax.value_and_grad is returned directly from nk.jax.value_and_grad. For R->C functions, however, the output is unpacked manually here: netket/netket/jax/_grad.py, line 167 in d444d83.
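One common way to build a value-and-gradient helper for an R->C function is to differentiate the real and imaginary parts separately and recombine them as grad_re + 1j * grad_im. The sketch below is hypothetical: value_and_grad_r2c is not NetKet's API, and NetKet's actual code and conjugation convention may differ. It only illustrates that with has_aux=True each inner jax.value_and_grad call must be unpacked as ((value, aux), grad):

```python
import jax
import jax.numpy as jnp


def value_and_grad_r2c(fun, has_aux=False):
    """Hypothetical value_and_grad for f: R^n -> C (illustration only, not NetKet's code).

    Differentiates Re f and Im f separately and combines the gradients as
    grad_re + 1j * grad_im. fun is evaluated twice per call.
    """

    def part(x, take):
        out = fun(x)
        if has_aux:
            val, aux = out
            return take(val), aux
        return take(out)

    def real_part(x):
        return part(x, jnp.real)

    def imag_part(x):
        return part(x, jnp.imag)

    def combine(g_re, g_im):
        # Works for arrays and arbitrary pytrees of parameters.
        return jax.tree_util.tree_map(lambda a, b: a + 1j * b, g_re, g_im)

    def wrapped(x):
        if has_aux:
            # Each call returns ((value, aux), grad), so unpack the nested tuple.
            (re, aux), g_re = jax.value_and_grad(real_part, has_aux=True)(x)
            (im, _), g_im = jax.value_and_grad(imag_part, has_aux=True)(x)
            return (re + 1j * im, aux), combine(g_re, g_im)
        re, g_re = jax.value_and_grad(real_part)(x)
        im, g_im = jax.value_and_grad(imag_part)(x)
        return re + 1j * im, combine(g_re, g_im)

    return wrapped


def f(x):
    # Hypothetical R^n -> C function with an auxiliary output.
    return jnp.sum(x ** 2) + 1j * jnp.sum(x), jnp.max(x)


(val, aux), g = value_and_grad_r2c(f, has_aux=True)(jnp.array([1.0, 2.0, 3.0]))
print(val, aux, g)
```

Evaluating fun twice keeps the sketch short; a production version would more likely use a single jax.vjp call and pull back two real cotangents.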
This was tested for: