Optimizing a variable loss function #233
I have an ansatz circuit C(p) parametrized by some parameters p, and a loss function L(C, q) that takes in the circuit C and some auxiliary variable q. I have a long list of values [q[1], ..., q[n]], and for each q[i] I wish to optimize L(C(p), q[i]) over the parameters p. Using TNOptimizer with the jax backend, the cost of this optimization is dominated by jit compilation of the gradient. My question: is there any way to avoid recompiling the gradient for each q[i]?

In principle, it should be possible to have a single compiled gradient function that takes q as an input. The trouble is, if I feed q in as a constant to TNOptimizer, the compiled gradient (seemingly) does not change at all if I go back and modify q. However, I also don't want to feed q in as a parameter to be optimized, because in any given iteration I fix q to q[i] and optimize only over p.

In case it's still unclear what I mean, below is an example based on https://quimb.readthedocs.io/en/latest/examples/ex_tn_train_circuit.html. The Hamiltonian H changes in each iteration, but the ansatz circuit is always the same. Can I restructure this code so that gradient compilation happens only once, instead of once per iteration?
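A minimal sketch of this pattern (the brickwork ansatz and the Heisenberg MPOs here are illustrative stand-ins for the actual circuit and the q[i], not the code from the linked example):

```python
import quimb.tensor as qtn

n, depth = 6, 4
circ = qtn.circ_ansatz_1D_brickwork(n, depth)  # the fixed ansatz C(p)
psi = circ.psi  # its tensor network; the gate tensors hold the parameters p

# one Hamiltonian per q[i] -- illustrative stand-ins for the real list
hams = [qtn.MPO_ham_heis(n, j=(1.0, 1.0, jz)) for jz in (0.5, 1.0, 1.5)]

def loss(psi, H):
    # energy <psi|H|psi> of the circuit state against an MPO Hamiltonian
    return qtn.expec_TN_1D(psi.H, H, psi).real

for H in hams:
    tnopt = qtn.TNOptimizer(
        psi,
        loss_fn=loss,
        loss_constants={'H': H},  # H is baked into the compiled function...
        constant_tags=['PSI0'],   # don't optimize the initial state tensors
        autodiff_backend='jax',
    )
    # ...so every pass through this loop re-traces and re-jits the gradient
    psi = tnopt.optimize(100)
```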
Replies: 1 comment
Hi @wkretschmer. No, I think sadly this kind of 'non-optimized variable' is not supported at the moment. It would be a nice feature to have, especially for jax, which indeed has very slow compilation.
My suggestions would be:
- use a different `autodiff_backend`, or try turning off jitting with `jit_fn=False` - might be slow however
- perform the optimization outside of `TNOptimizer` - see e.g. https://quimb.readthedocs.io/en/latest/examples/ex_quimb_within_jax_flax_optax.html - maybe with the `qtn.pack` / `qtn.unpack` functions (a sketch of this follows below)
- add it as a new feature, e.g. a 'loss_variables'…
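To illustrate the second suggestion, here is a hedged sketch of compiling the gradient once by running the loop outside `TNOptimizer`: `qtn.pack` splits a network into a pytree of raw arrays plus a skeleton, so both the circuit parameters and the Hamiltonian arrays can be traced arguments of a single jitted function. The plain gradient-descent step and the learning rate are illustrative only, and this skips details that `TNOptimizer` normally handles (keeping the 'PSI0' tensors constant, re-normalizing the state, using a proper optimizer):

```python
import jax
import quimb.tensor as qtn

# illustrative setup mirroring the question: fixed ansatz, varying Hamiltonians
n, depth = 6, 4
psi = qtn.circ_ansatz_1D_brickwork(n, depth).psi
hams = [qtn.MPO_ham_heis(n, j=(1.0, 1.0, jz)) for jz in (0.5, 1.0, 1.5)]

# pack once: a pytree of raw arrays plus a skeleton to rebuild each network
params, skeleton = qtn.pack(psi)
_, h_skeleton = qtn.pack(hams[0])  # all the MPOs share this structure

def loss(params, h_params):
    psi = qtn.unpack(params, skeleton)
    H = qtn.unpack(h_params, h_skeleton)
    return qtn.expec_TN_1D(psi.H, H, psi).real

# traced and compiled once -- calls with new same-shaped h_params reuse it
loss_and_grad = jax.jit(jax.value_and_grad(loss))

for H in hams:
    h_params, _ = qtn.pack(H)
    for _ in range(100):
        value, grads = loss_and_grad(params, h_params)
        # illustrative plain gradient-descent update over the pytree of arrays
        params = jax.tree_util.tree_map(lambda p, g: p - 0.05 * g, params, grads)

psi_opt = qtn.unpack(params, skeleton)  # rebuild the optimized network
```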