Can we make vectorized functions compatible with JAX's numpy interface? #310
Today I experimented a little with JAX, trying to port `norm_vector`:

```python
# test_jax.py
import timeit

import jax.numpy as jnp
import numpy as np
from jax import jit

from pytransform3d.rotations._utils import norm_vector as norm_vector_original


# the same function would work with `np.linalg` and `np.where`
def norm_vector_jax(v):
    """Normalize vector.

    Parameters
    ----------
    v : array-like, shape (n,)
        nd vector

    Returns
    -------
    u : array, shape (n,)
        nd unit vector with norm 1 or the zero vector
    """
    norm = jnp.linalg.norm(v, axis=0, ord=2)
    result = jnp.where(norm == 0.0, v, jnp.asarray(v) / norm)
    return result


if __name__ == "__main__":
    N = 1000000  # a large array
    v = np.random.randn(N)

    # JIT compilation here
    norm_vector_jax_jit = jit(norm_vector_jax)

    print("Original:", norm_vector_original(v))
    print("No-JIT:", norm_vector_jax(v))
    print("JIT:", norm_vector_jax_jit(v))

    def original_func(): return norm_vector_original(v)
    def jax_func(): return norm_vector_jax(v)
    def jax_jit_func(): return norm_vector_jax_jit(v)

    num_executions = 100
    t_original = timeit.timeit(
        original_func, number=num_executions) / num_executions
    t_jax = timeit.timeit(jax_func, number=num_executions) / num_executions
    t_jax_jit = timeit.timeit(
        jax_jit_func, number=num_executions) / num_executions

    print("\nBenchmark:")
    print("Original:", t_original)
    print("No-JIT:", t_jax)
    print("JIT:", t_jax_jit)
```

In my small benchmark, the JIT-compiled version is about twice as fast as the NumPy original.
You are right. It will be hard to implement the more complex functionality with JAX support, though: every operation involved would have to be JAX compatible (e.g., ScLERP, which uses a lot of basic vectorized functions). I also tried to convert some of the more complex functions. The main problem is operations like `A[..., :3, :3] = ...`. Instead of doing this, I'd recommend computing the parts and concatenating them later. Not sure how well that works with conditional code, though.
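For illustration only, here is a minimal sketch (not from the original discussion) of the two ways such an assignment could be expressed in JAX: functional index updates via `.at[...].set()`, or building the matrix from its parts as suggested above. The names `R`, `p`, `A`, and `B` are hypothetical placeholders.

```python
import jax.numpy as jnp

R = jnp.eye(3)               # hypothetical rotation matrix
p = jnp.array([1., 2., 3.])  # hypothetical translation

# Option 1: functional "in-place" update. JAX arrays are immutable,
# so A[:3, :3] = R becomes A.at[:3, :3].set(R), which returns a new array.
A = jnp.eye(4)
A = A.at[:3, :3].set(R)
A = A.at[:3, 3].set(p)

# Option 2: compute the parts and concatenate them afterwards.
top = jnp.concatenate([R, p[:, None]], axis=1)  # shape (3, 4)
bottom = jnp.array([[0., 0., 0., 1.]])          # shape (1, 4)
B = jnp.concatenate([top, bottom], axis=0)      # shape (4, 4)

# For conditional code inside jitted functions, jnp.where (elementwise)
# or jax.lax.cond (branch selection) replace Python if/else.
```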
Making the vectorized functions compatible with JAX's numpy interface would make JIT compilation and GPU parallelization possible.
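As a rough sketch of what that parallelization could look like (this is an assumption, not code from the issue), `jax.vmap` can batch a single-vector function over many vectors and `jax.jit` can compile the whole batched computation for CPU, GPU, or TPU:

```python
import jax
import jax.numpy as jnp


def norm_vector_jax(v):
    """Normalize a single vector (same logic as the snippet above)."""
    norm = jnp.linalg.norm(v)
    return jnp.where(norm == 0.0, v, v / norm)


# vmap maps the single-vector function over the leading batch axis,
# jit compiles the resulting batched computation.
batched_norm = jax.jit(jax.vmap(norm_vector_jax))

V = jax.random.normal(jax.random.PRNGKey(0), (10000, 3))
U = batched_norm(V)  # runs on GPU/TPU if one is available
print(U.shape)       # (10000, 3)
```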