
July 17 2024: Experimentation of Jax Jit version 2

GourlieK edited this page Jul 17, 2024 · 9 revisions

A new approach to using jax.jit was to remove every jitted method from the class and to apply jax.jit only to functions defined outside of objects. To do this, the computation of NcalInv was moved out of a class property and into a standalone function. Note that the function get_NcalInv_RRF takes only arrays and flags as inputs, never an object.

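from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=['full_matrix', 'return_Gtilde_Ncal'])
def get_NcalInv_RRF(K_inv: jax.Array, G: jax.Array, CgwInv: jax.Array,
                    CirnInv: jax.Array, freqs: jax.Array, toas: jax.Array,
                    full_matrix=False, return_Gtilde_Ncal=False):
    """
    Args:
        full_matrix (bool, optional): If True, return the full TfN matrix rather than its diagonal. Defaults to False.
        return_Gtilde_Ncal (bool, optional): If True, also return Gtilde and Ncal. Defaults to False.

    Returns:
        jax.Array or tuple: real part of TfN (diagonal divided by T, or full matrix), optionally with Gtilde and Ncal.
    """
    # K_inv is passed in directly, so there is no self.N or self.K_inv to check here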
    nf = len(freqs)
    N = len(toas)
    T = toas.max()-toas.min()
    #Fourier Design matrix
    F  = jnp.zeros((N, 2 * nf))
    f = jnp.arange(1, nf + 1) / T
    # JAX arrays are immutable, so the sine/cosine columns are filled with .at[...].set(...)
    F = F.at[:, ::2].set(jnp.sin(2 * jnp.pi * toas[:, None] * f[None, :]))
    F = F.at[:, 1::2].set(jnp.cos(2 * jnp.pi * toas[:, None] * f[None, :]))
    del f
    J = jnp.matmul(G.T, F)
    del F
    Z = jnp.matmul(J.T, K_inv)
    Sigma = jnp.matmul(Z, J) + CirnInv + CgwInv
    SigmaInv = jnp.linalg.inv(Sigma)

    del J, Sigma
    Gtilde = jnp.zeros((freqs.size,G.shape[1]),dtype='complex128')
    Gtilde = jnp.dot(jnp.exp(1j*2*jnp.pi*freqs[:,jnp.newaxis]*toas),G)

    NcalInv = (K_inv + jnp.matmul(Z.T, jnp.matmul(SigmaInv, Z))) / 2  # divide by 2 for some reason

    del SigmaInv, Z, K_inv
    # divided by 2 again here, so the net factor is 1/4 rather than 1/2; possibly a normalization convention
    TfN = jnp.matmul(jnp.conjugate(Gtilde),jnp.matmul(NcalInv,Gtilde.T)) / 2

    if return_Gtilde_Ncal:
        return jnp.real(TfN), Gtilde, jnp.linalg.inv(NcalInv)
    elif full_matrix:
        return jnp.real(TfN)
    else:
        return jnp.real(jnp.diag(TfN)) / T
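
The pattern used here, a jitted module-level function that receives only arrays while the object stays un-jitted, can be sketched in isolation as follows. The names `scaled_gram` and `Spectrum` are hypothetical stand-ins for `get_NcalInv_RRF` and the original class, and the math is a toy computation rather than the NcalInv calculation itself.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for get_NcalInv_RRF: a pure function that receives
# only arrays and flags, so jax.jit never has to trace `self`.
@partial(jax.jit, static_argnames=["full_matrix"])
def scaled_gram(G: jax.Array, K_inv: jax.Array, full_matrix: bool = False):
    M = jnp.matmul(G.T, jnp.matmul(K_inv, G))
    # `full_matrix` is static, so this Python branch is resolved at trace time.
    if full_matrix:
        return M
    return jnp.diag(M)


class Spectrum:
    """Hypothetical container that holds the arrays and delegates the math."""

    def __init__(self, G, K_inv):
        self.G = G
        self.K_inv = K_inv

    @property
    def NcalInv_diag(self):
        # Plain (un-jitted) property: only arrays cross into the jitted call.
        return scaled_gram(self.G, self.K_inv)


G = jnp.eye(3)
K_inv = jnp.diag(jnp.array([1.0, 2.0, 3.0]))
spec = Spectrum(G, K_inv)
diag_result = spec.NcalInv_diag
full_result = scaled_gram(G, K_inv, full_matrix=True)
```

Because the flag is listed in `static_argnames`, each distinct value triggers its own compilation, which is why the Python `if` on it is allowed inside the jitted function.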

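The second modification in this implementation of jax.jit was to persist the compilation results to disk, specifically in a temporary directory. What JAX stores there is the compiled XLA executable for each jitted function; this page refers to those cached entries as tracers. Here is the code that is responsible for writing them: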

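import os
import jax

os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# Cache every compilation, no matter how short the compile time was
jax.config.update('jax_persistent_cache_min_compile_time_secs', 0)
# Imported in the original script; not used in the lines shown here
from jax.experimental.compilation_cache import compilation_cache as cc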
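
As a rough way to see what these settings do, the sketch below points the cache at a throwaway directory and times the first and second call of a small jitted function. The function `gram` and the use of `tempfile.mkdtemp` are assumptions made for illustration (the page itself uses `/tmp/jax_cache`), and whether files actually appear in the directory may depend on the JAX version and backend.

```python
import os
import tempfile
import time

import jax
import jax.numpy as jnp

# Assumption: a throwaway directory instead of /tmp/jax_cache.
cache_dir = tempfile.mkdtemp(prefix="jax_cache_")
jax.config.update("jax_compilation_cache_dir", cache_dir)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)


@jax.jit
def gram(x):
    # Hypothetical toy workload, not the NcalInv computation.
    return jnp.matmul(x.T, x)


x = jnp.ones((64, 8))

t0 = time.perf_counter()
first = gram(x).block_until_ready()   # traces and compiles (or loads from disk)
t1 = time.perf_counter()
second = gram(x).block_until_ready()  # reuses the in-memory executable
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.4f} s")
print(f"second call: {t2 - t1:.4f} s")
# Whether entries show up here depends on the JAX version and backend.
print("cache entries:", os.listdir(cache_dir))
```

Within a single process the second call is served from memory either way; the on-disk cache only comes into play when a new process compiles the same function again, which matches the first-run versus second-run comparison made on this page.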

The following plots are from the first execution of the modified code, before the tracers had been written to disk, along with the profile data folder:

[Images: Jax Jit Tracer Creation | Screenshot from 2024-07-17 13-09-48 | mem_time | time_bar]

The following plots are from the second execution of the modified code, after the tracers had been written to disk and the program could read them back, along with the profile data folder and the tracer folder itself:

[Images: Jax Jit Tracer Usage | Screenshot from 2024-07-17 13-15-50 | mem_time | time_bar]

As we can see, saving the tracers to disk makes an insignificant difference in both time and memory. Now let's look at running without any tracers on disk and see whether jitting the function alone improves efficiency in any way. To do this, I will comment out the commands responsible for saving the tracers to disk. Here are the results:

[Images: Jax Jit no Tracer | Screenshot from 2024-07-17 13-30-32 | mem_time | time_bar]

Last, let's compare it with the baseline, no jit. To do this, I will comment out the partial decorator. Here are the results:

[Images: No Jax Jit | Screenshot from 2024-07-17 13-45-26 | mem_time | time_bar]
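
The jit versus no-jit comparison can also be reproduced in isolation with a small timing harness. The sketch below is a minimal example under stated assumptions: `quad_form` is a hypothetical toy workload, not the NcalInv computation, and `time_call` is a helper written for this illustration. It measures wall-clock time only, not memory.

```python
import time

import jax
import jax.numpy as jnp


def quad_form(A, x):
    # Small stand-in workload; not the NcalInv computation itself.
    return jnp.matmul(x.T, jnp.matmul(A, x))


quad_form_jit = jax.jit(quad_form)


def time_call(fn, *args, repeats=5):
    """Return the best wall-clock time over `repeats` calls, in seconds."""
    fn(*args).block_until_ready()  # warm-up: keeps compile time out of the timing
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args).block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best


A = jnp.eye(256)
x = jnp.ones((256, 4))
eager_s = time_call(quad_form, A, x)
jit_s = time_call(quad_form_jit, A, x)
print(f"eager: {eager_s:.6f} s, jit: {jit_s:.6f} s")
```

Calling `block_until_ready()` matters because JAX dispatches work asynchronously; without it the timer may stop before the computation has finished.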


In these runs, the only speedups that JAX offers come from GPU execution within its module jax.numpy. The use of jax.jit is not only not beneficial here, but also increases memory usage due to the additional XLA compilation and the tracers it creates.