How to properly use quimb with jax #219
Hi @jcmgray :) I'd like to know the best practices and suggested ways to use quimb together with jax. Consider the minimal example below of minimising the energy of a state. I have two main questions about the jitting behaviour, which I discuss below the example:
```python
import jax
import numpy as np
import quimb.tensor as qtn

qtn.interface.jax_register_pytree()

# Expectation value
def loss_fn(state, observable):
    s1, obs, s2 = qtn.tensor_network_align(state, observable, state.H)
    return (s1 | obs | s2).contract()

# Initial random state and Pauli-Z observable on all sites
num_sites = 5
mps = qtn.MPS_rand_state(num_sites, 1)
mpo = qtn.MPO_product_operator(np.array([[[1, 0], [0, -1]],] * num_sites))

# Jax-ify the functions in different ways
nojit_fn = jax.value_and_grad(loss_fn)
jit_fn1 = jax.jit(jax.value_and_grad(loss_fn))
jit_fn2 = jax.jit(jax.value_and_grad(loss_fn), static_argnums=[0, 1])

print(f"Initial loss = {loss_fn(mps, mpo)}\n")

# Run these so that they are compiled, if needed
nojit_fn(mps, mpo)
jit_fn1(mps, mpo)
jit_fn2(mps, mpo)

def optimize(update_function):
    """Basic gradient descent update rule."""
    psi = mps.copy()
    for i in range(30):
        psi = psi.multiply(1 / psi.norm(), spread_over='all')
        val, grad = update_function(psi, mpo)
        # Update parameters
        new_params = jax.tree_map(
            lambda x, y: x - 0.1 * y, psi.get_params(), grad.get_params()
        )
        psi.set_params(new_params)
        print(f"Step: {i} — Loss: {val}", end="\r")
    return psi.multiply(1 / psi.norm(), spread_over='all')

import time

print("No jit function")
start = time.time()
optimize(nojit_fn)
end = time.time()
print("\nExecution time [s]:", end - start)
print("")

print("Jit function")
start = time.time()
optimize(jit_fn1)
end = time.time()
print("\nExecution time [s]:", end - start)
print("")

print("Jit function with static_argnums")
start = time.time()
optimize(jit_fn2)
end = time.time()
print("\nExecution time [s]:", end - start)
```

Here is the output:

```
>>> Initial loss = 0.03858750129380334
>>> No jit function
>>> Step: 29 — Loss: -0.99956581221671596
>>> Execution time [s]: 0.40987300872802734
>>> Jit function
>>> Step: 29 — Loss: -0.99956581221671596
>>> Execution time [s]: 1.5457868576049805
>>> Jit function with static_argnums
>>> Step: 29 — Loss: -0.99956581221671596
>>> Execution time [s]: 0.7832067012786865
```

Note that if we don't redefine `psi` at every iteration (i.e. if we skip the normalization step that creates a new object), the jitted functions run fast, which suggests that a re-compilation is triggered at every call. Similar findings are obtained by measuring the runtimes of the single gradient calls directly. The code above is just a simple example to compare the performance of the single gradient computations (with or without jitting). Of course, the proper way to speed up the whole optimization would be to properly write and jit the whole training loop. Thank you so much!
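One quick way to check whether re-compilation is really the culprit is a tracing probe: a sketch, reusing `loss_fn`, `mps` and `mpo` from above, where a Python-level print only executes while jax traces the function.

```python
# Hypothetical tracing probe: "tracing!" should print once on the first
# call; if it prints again for a fresh copy of the state, the new pytree
# triggered a re-trace (and hence a re-compilation).
def traced_loss(state, observable):
    print("tracing!")
    return loss_fn(state, observable)

probe = jax.jit(jax.value_and_grad(traced_loss))
probe(mps, mpo)         # first call: compiles, prints "tracing!"
probe(mps, mpo)         # same object: expect a cache hit, no print
probe(mps.copy(), mpo)  # fresh copy: a second "tracing!" confirms
                        # per-object re-compilation
```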
---
I've explored the issue a little more deeply, and I've seen that jitting a function that acts on the raw parameters extracted with `qtn.pack` (passing the skeleton as a static argument) avoids the problem. See the modified example here:

```python
import jax
import numpy as np
import quimb.tensor as qtn

# Expectation value
def loss_fn(params, skeleton):
    state = qtn.unpack(params, skeleton)
    s1, obs, s2 = qtn.tensor_network_align(state, observable, state.H)
    return (s1 | obs | s2).contract()

# Initial random state and Pauli-Z observable on all sites
num_sites = 5
state = qtn.MPS_rand_state(num_sites, 1)
params, skeleton = qtn.pack(state)
observable = qtn.MPO_product_operator(np.array([[[1, 0], [0, -1]],] * num_sites))

# Jax-ify the functions in different ways
nojit_fn = jax.value_and_grad(loss_fn)
jit_fn = jax.jit(jax.value_and_grad(loss_fn), static_argnums=[1])

print(f"Initial loss = {loss_fn(params, skeleton)}\n")

# Run these so that they are compiled, if needed
print(nojit_fn(params, skeleton))
print(jit_fn(params, skeleton))
print("")

def optimize(update_function):
    """Basic gradient descent update rule."""
    psi = state.copy()
    params = psi.get_params()
    for i in range(50):
        # Normalize state
        psi.set_params(params)
        psi = psi.multiply(1 / psi.norm(), spread_over='all')
        params = psi.get_params()
        val, grad = update_function(params, skeleton)
        # Update parameters
        params = jax.tree_map(lambda x, y: x - 0.1 * y, params, grad)
        print(f"Step: {i} — Loss: {val}", end="\r")
    return psi.multiply(1 / psi.norm(), spread_over='all')

import time

print("No jit function")
start = time.time()
optimize(nojit_fn)
end = time.time()
print("\nExecution time [s]:", end - start)
print("")

print("Jit function")
start = time.time()
optimize(jit_fn)
end = time.time()
print("\nExecution time [s]:", end - start)
print("")
```

Output:

```
>>> Initial loss = -0.02904531122882376
>>> (Array(-0.02904531, dtype=float32), {0: ..., ...})
>>> (Array(-0.02904531, dtype=float32), {0: ..., ...})
>>> No jit function
>>> Step: 49 — Loss: -1.09999964237213135
>>> Execution time [s]: 0.622642993927002
>>> Jit function
>>> Step: 49 — Loss: -1.09999964237213135
>>> Execution time [s]: 0.04969906806945801
```

I thus confirm my concerns that the jitted function gets re-compiled at every single call when quimb objects are passed to it directly, whereas a jitted function acting only on the raw parameters (with the skeleton held as a static argument) is compiled once and is then very fast. Given these observations, I would then conclude:

- for simple optimizations, the "jax in quimb" route (e.g. quimb's `TNOptimizer`) is the most convenient;
- for custom jax code, one should jit functions acting on the raw parameters obtained via `qtn.pack`/`qtn.unpack`, ideally jitting the whole training loop (see the sketch below), rather than passing quimb objects to jitted functions directly.
Do you agree? Sorry for the long and probably rambling questions, thank you so much in advance!
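For reference, here is a minimal sketch of what jitting the whole loop could look like, reusing `params`, `skeleton` and `loss_fn` from the example above; the normalization step is omitted for brevity, and the step count and learning rate are arbitrary placeholders.

```python
from functools import partial

import jax

# Stage the entire gradient-descent loop into one compiled program: the
# skeleton is static, so this compiles once, and only raw arrays are traced.
@partial(jax.jit, static_argnums=(1,))
def run(params, skeleton):
    def body(i, p):
        grad = jax.grad(loss_fn)(p, skeleton)
        # Plain gradient-descent update on the parameter pytree
        return jax.tree_util.tree_map(lambda x, g: x - 0.1 * g, p, grad)
    return jax.lax.fori_loop(0, 50, body, params)

final_params = run(params, skeleton)
final_state = qtn.unpack(final_params, skeleton)
```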
---
Hi @stfnmangini, sorry to be slow getting to this and thanks for the detailed examples! Indeed I get the same results when I run them.

Yes, at a high level the aim is to allow both the "jax in quimb" approach (`TNOptimizer`) for simple things and the "quimb in jax" approach (where quimb just orchestrates various array operations) for detailed jax things.

Certainly my understanding of the `jax_register_pytree` functionality was that it should enable jittable functions to accept/return quimb structures. However, I have actually not looked much into this direction, and so am not aware if this re-compilation thing is a bug or some misunderstanding of how pytrees work in jax - I can try and look into it but …

I think for the moment your conclusions are what I would also suggest.
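For what it's worth, the re-compilation behaviour is at least consistent with how jit caches work on pytrees: the auxiliary data returned by the flatten function becomes part of the treedef, and a cached trace is only reused when the treedefs (and hence the aux data) compare equal. Below is a small self-contained illustration with a hypothetical `Box` container (not quimb's actual registration, just an assumption about the mechanism):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

class Box:
    """Hypothetical container standing in for a TN-like object."""
    def __init__(self, data, meta):
        self.data = data  # array leaf
        self.meta = meta  # auxiliary (static) metadata

register_pytree_node(
    Box,
    lambda b: ((b.data,), b.meta),              # flatten -> (leaves, aux)
    lambda meta, leaves: Box(leaves[0], meta),  # unflatten
)

@jax.jit
def f(b):
    print("tracing!")  # only fires while jax traces
    return (b.data ** 2).sum()

x = jnp.arange(3.0)
f(Box(x, meta=object()))  # traces and compiles
f(Box(x, meta=object()))  # re-traces: the two aux objects compare unequal
f(Box(x, meta="tag"))     # traces once for this aux value
f(Box(x, meta="tag"))     # cache hit: equal aux data, no re-trace
```

If quimb's flatten were to place a freshly created structure into the aux data, every new tensor network instance would then trigger a re-trace, which would explain the timings above, though this is speculation rather than a confirmed diagnosis.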