In the `lineax.materialise()` docs there is the following example:
```python
import lineax as lx

operator = lx.FunctionLinearOperator(large_function, ...)

# Option 1
out1 = operator.mv(vector1)  # Traces and compiles `large_function`
out2 = operator.mv(vector2)  # Traces and compiles `large_function` again!
out3 = operator.mv(vector3)  # Traces and compiles `large_function` a third time!
# All that compilation might lead to long compile times.
# If `large_function` takes a long time to run, then this might also lead to long
# run times.

# Option 2
operator = lx.materialise(operator)  # Traces and compiles `large_function` and
                                     # stores the result as a matrix.
out1 = operator.mv(vector1)  # Each of these just computes a matrix-vector product
out2 = operator.mv(vector2)  # against the stored matrix.
out3 = operator.mv(vector3)
# Now, `large_function` is only compiled once, and only run once.
# However, storing the matrix might take a lot of memory, and the initial
# computation may-or-may-not take a long time to run.
```
In option 1, why is JAX tracing and recompiling the operator for every new input? Does this mean that while using an iterative solver, the operator is recompiled in every iteration? Is it possible to jit it only once?
So JAX will trace and compile separately for every call site in your code. If you're familiar with traditional compiled languages, then basically what is happening is that absolutely every function call is inlined.

The reason for this, in large part, is the way JAX does JIT compilation: it runs your code with "tracers" that record every operation that happens to them, but this means that it can't see Python-level constructs like functions, for loops, and so on. See point 7 in this post.
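A minimal sketch of this inlining behaviour in plain JAX. It assumes a cheap stand-in for `large_function` and a hypothetical `trace_count` counter. The counter is a Python side effect, so it only increments while JAX is tracing, never when the compiled program runs.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical: counts how often the Python body of `large_function` runs

def large_function(x):
    # Hypothetical cheap stand-in for an expensive function.
    global trace_count
    trace_count += 1  # Python side effect: happens only during tracing
    return jnp.sin(x) * 2.0

@jax.jit
def three_calls(v1, v2, v3):
    # Each call site is traced separately, so the operations of
    # `large_function` are inlined three times into one compiled program.
    return large_function(v1), large_function(v2), large_function(v3)

x = jnp.ones(3)
three_calls(x, x, x)
after_first = trace_count
print(after_first)  # 3: one trace per call site
three_calls(x, x, x)
after_second = trace_count
print(after_second)  # still 3: same shapes/dtypes, so the cached program is reused
```

The second call shows that, once compiled, a jitted function is not retraced for new inputs of the same shape and dtype. The repeated work comes from the three separate call sites.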
In an iterative solver, we wrap the multiple invocations inside a `jax.lax.while_loop`. As this is a JAX-level loop construct, JAX knows about the loop, and so it only needs to compile its body function once, not once per iteration.
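A toy illustration of the loop case. It assumes a hypothetical `matvec` standing in for `operator.mv`, a hypothetical 2x2 matrix `A`, and a simple Richardson iteration (`richardson`, also hypothetical). It is not any solver Lineax actually ships. The Python body of `matvec` runs once, during tracing, even though the loop body executes 100 times.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical: counts how often the Python body of `matvec` runs

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # hypothetical symmetric positive-definite matrix
b = jnp.array([1.0, 2.0])

def matvec(x):
    # Hypothetical stand-in for `operator.mv`.
    global trace_count
    trace_count += 1
    return A @ x

def richardson(b, omega=0.2, num_iters=100):
    # Toy Richardson iteration x <- x + omega * (b - A x), written with
    # jax.lax.while_loop so that the loop body is traced only once.
    def cond_fun(carry):
        i, _ = carry
        return i < num_iters

    def body_fun(carry):
        i, x = carry
        return i + 1, x + omega * (b - matvec(x))

    # Explicit dtypes keep the carry strongly typed across iterations.
    init = (jnp.asarray(0, dtype=jnp.int32), jnp.zeros_like(b))
    _, x = jax.lax.while_loop(cond_fun, body_fun, init)
    return x

x = richardson(b)
print(trace_count)  # 1, even though the body ran 100 times
```

The loop bound lives in `cond_fun`, so JAX sees a single loop body rather than 100 unrolled copies of `matvec`.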
In cases like Option 1 above, another option (other than `lx.materialise`) is to call either
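One general JAX technique for a situation like Option 1 applies when all the input vectors are available at once: batch them with `jax.vmap`, so that `large_function` is traced a single time. This is a sketch under that assumption, not necessarily the option the reply above had in mind. `apply_three`, `trace_count`, and the body of `large_function` are hypothetical.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical: counts how often the Python body of `large_function` runs

def large_function(x):
    # Hypothetical stand-in for an expensive linear function.
    global trace_count
    trace_count += 1
    return 2.0 * x + jnp.roll(x, 1)

@jax.jit
def apply_three(vectors):
    # `vectors` has shape (3, n): one row per input vector.
    # jax.vmap traces `large_function` once with a batched input,
    # instead of once per call site.
    return jax.vmap(large_function)(vectors)

v1, v2, v3 = jnp.ones(4), jnp.arange(4.0), jnp.zeros(4)
out = apply_three(jnp.stack([v1, v2, v3]))
out1, out2, out3 = out  # unpack the rows of the batched result
print(trace_count)  # 1
```

The trade-off is that every vector must be known up front and the batch is held in memory as one array. If the vectors are produced one at a time, this does not apply.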