July 17 2024: Experimentation of Jax Jit version 2
A new method for implementing jax.jit
was to completely remove any jitted methods and instead apply jit only to functions defined outside of objects. To do this, the computation of NcalInv was moved from a class property to a standalone function. Note that the function get_NcalInv_RRF
takes only arrays as inputs, not objects.
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=['full_matrix', 'return_Gtilde_Ncal'])
def get_NcalInv_RRF(K_inv: jax.Array, G: jax.Array, CgwInv: jax.Array,
                    CirnInv: jax.Array, freqs: jax.Array, toas: jax.Array,
                    full_matrix=False, return_Gtilde_Ncal=False):
    """Compute the inverse-noise-weighted transmission function TfN.

    Args:
        K_inv (jax.Array): Inverse of the G-projected white-noise kernel.
        G (jax.Array): Timing-model projection (G) matrix.
        CgwInv (jax.Array): Inverse GW-background covariance in the Fourier basis.
        CirnInv (jax.Array): Inverse intrinsic red-noise covariance in the Fourier basis.
        freqs (jax.Array): Frequencies at which to evaluate the transmission function.
        toas (jax.Array): Pulsar times of arrival.
        full_matrix (bool, optional): If True, return the full TfN matrix. Defaults to False.
        return_Gtilde_Ncal (bool, optional): If True, also return Gtilde and Ncal. Defaults to False.

    Returns:
        jax.Array: The real diagonal of TfN divided by the timespan (default), the full
        real TfN matrix, or (TfN, Gtilde, Ncal) depending on the flags.
    """
    nf = len(freqs)
    N = len(toas)
    T = toas.max() - toas.min()
    # Fourier design matrix
    F = jnp.zeros((N, 2 * nf))
    f = jnp.arange(1, nf + 1) / T
    F = F.at[:, ::2].set(jnp.sin(2 * jnp.pi * toas[:, None] * f[None, :]))
    F = F.at[:, 1::2].set(jnp.cos(2 * jnp.pi * toas[:, None] * f[None, :]))
    del f
    J = jnp.matmul(G.T, F)
    del F
    Z = jnp.matmul(J.T, K_inv)
    Sigma = jnp.matmul(Z, J) + CirnInv + CgwInv
    SigmaInv = jnp.linalg.inv(Sigma)
    del J, Sigma
    Gtilde = jnp.dot(jnp.exp(1j * 2 * jnp.pi * freqs[:, jnp.newaxis] * toas), G)
    NcalInv = (K_inv + jnp.matmul(Z.T, jnp.matmul(SigmaInv, Z))) / 2  # divide by 2 for some reason
    del SigmaInv, Z, K_inv
    # divided by 4 in total, not 2, for some reason; possibly some normalization stuff
    TfN = jnp.matmul(jnp.conjugate(Gtilde), jnp.matmul(NcalInv, Gtilde.T)) / 2
    if return_Gtilde_Ncal:
        return jnp.real(TfN), Gtilde, jnp.linalg.inv(NcalInv)
    elif full_matrix:
        return jnp.real(TfN)
    else:
        return jnp.real(jnp.diag(TfN)) / T
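Since full_matrix and return_Gtilde_Ncal control Python-level branching and the structure of the return value, they have to be declared in static_argnames; a traced (non-static) argument cannot drive an ordinary if. Here is a minimal, self-contained sketch of the same pattern (the toy function and shapes are hypothetical, not from the codebase):

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=["full_matrix"])
def toy(x, full_matrix=False):
    out = jnp.outer(x, x)
    if full_matrix:  # Python-level branch: only allowed because full_matrix is static
        return out
    return jnp.diag(out)

x = jnp.arange(3.0)
print(toy(x, full_matrix=False).shape)  # (3,)
print(toy(x, full_matrix=True).shape)   # (3, 3)
```

Each distinct value of a static argument triggers its own trace and compile, which is why they are kept to small flags like these.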
The second modification in this implementation of jax.jit
was writing the compiled traces to disk via JAX's persistent compilation cache, specifically within a temporary directory. Here is the code responsible for writing the cache:
import os

import jax
from jax.experimental.compilation_cache import compilation_cache as cc

# Point the persistent compilation cache at /tmp/jax_cache (set redundantly
# via the environment variable, the config flag, and the cache module itself)
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)  # cache even fast compilations
cc.set_cache_dir("/tmp/jax_cache")
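To sanity-check that the cache is actually being populated, one can point it at a fresh directory and list its contents after the first jitted call. This is a hedged sketch: the temporary directory stands in for /tmp/jax_cache, and whether anything is written depends on the JAX version and on the backend's support for the persistent cache.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

cache_dir = tempfile.mkdtemp()  # hypothetical stand-in for /tmp/jax_cache
jax.config.update("jax_compilation_cache_dir", cache_dir)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

@jax.jit
def f(x):
    return jnp.sin(x) ** 2

# First call: trace, compile, and (if supported) write the executable to the cache
result = f(jnp.arange(4.0)).block_until_ready()
print(sorted(os.listdir(cache_dir)))  # non-empty if the persistent cache is active here
```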
The following plots are from the first execution of the modified code, before the tracers were written to disk, along with the profile-data folder:
Jax Jit Tracer Creation Profile_Data.zip
The following plots are from the second execution of the modified code, after the tracers were written to disk and the program could read them back from the cache, along with the profile-data folder and the cache folder itself:
Jax Jit Tracer Usage Profile_Data_with_JaxCache.zip
As we can see, saving the tracers to disk makes an insignificant difference in both time and memory. Now let's run with no tracer cache at all and see whether jitting the function improves efficiency in any way. To do this, I will comment out the commands responsible for saving the tracers to disk. Here are the results:
Jax Jit no Tracer Profile_Data_jit_notrace.zip
Lastly, let's compare against the baseline with no jit. To do this, I will comment out the partial decorator. Here are the results:
No Jax Jit Profile_Data_nojaxjit.zip
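The jit overhead captured in the profiles above can also be seen directly by timing the first call (which traces and compiles) against a repeat call on the same shapes. The toy body below is a hypothetical stand-in for get_NcalInv_RRF, not the real computation:

```python
import time

import jax
import jax.numpy as jnp

def body(x):
    # stand-in for the real computation; any dense linear algebra works
    return jnp.real(jnp.diag(x @ x.T))

jitted = jax.jit(body)
x = jnp.ones((200, 200))

t0 = time.perf_counter()
first = jitted(x).block_until_ready()   # includes tracing and XLA compilation
t1 = time.perf_counter()
second = jitted(x).block_until_ready()  # reuses the compiled executable
t2 = time.perf_counter()
print(f"first call: {t1 - t0:.4f}s, second call: {t2 - t1:.4f}s")
```

If the function is only called a handful of times, as in these profiles, the one-time compile cost can dominate any per-call savings.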
The only speedups JAX offers here come from GPU execution of operations in its module jax.numpy
. The use of jax.jit
is not only not beneficial, but it increases memory usage due to the overhead of XLA compilation and the additional tracers.