
Can we make vectorized functions compatible with JAX's numpy interface? #310

Open
AlexanderFabisch opened this issue Nov 27, 2024 · 2 comments


@AlexanderFabisch (Member)

This would make JIT compilation and GPU parallelization possible.
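For context, here is a minimal sketch of what this would enable once a function is written purely against jax.numpy; the norm function below is just an illustrative stand-in, not pytransform3d's implementation:

import jax
import jax.numpy as jnp


def norm_vector(v):
    # pure jax.numpy code, so it is jit- and vmap-compatible
    norm = jnp.linalg.norm(v)
    return jnp.where(norm == 0.0, v, v / norm)


# compile once and map over a whole batch of vectors;
# the same code runs on a GPU if one is available
batched_norm = jax.jit(jax.vmap(norm_vector))
units = batched_norm(jnp.ones((1000, 3)))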

@kopytjuk (Contributor)

kopytjuk commented Dec 6, 2024

Today I experimented a little with JAX, trying to port pytransform3d.rotations._utils.norm_vector, and it looks like we have to adjust the existing code (see below). I achieved a 2x speedup. There is a complete guide on the potential pitfalls of porting functions to JAX, so what needs to change is at least well documented.

norm_vector was an easy function, though; it feels to me like adapting functions such as dual_quaternions_from_screw_parameters would be a lot of work. I think we first need to identify which functions in the library are the slowest and rewrite/compile those with JAX (see the profiling sketch at the end of this comment). Rewriting everything to JAX is a large effort with little benefit, imho.

# test_jax.py
import timeit

import jax.numpy as jnp
import numpy as np
from jax import jit

from pytransform3d.rotations._utils import norm_vector as norm_vector_original


# the same implementation would also work with NumPy's `np.linalg` and `np.where`
def norm_vector_jax(v):
    """Normalize vector.

    Parameters
    ----------
    v : array-like, shape (n,)
        nd vector

    Returns
    -------
    u : array, shape (n,)
        nd unit vector with norm 1 or the zero vector
    """

    norm = jnp.linalg.norm(v, axis=0, ord=2)
    # jnp.where evaluates both branches, but selects the unmodified input
    # when the norm is zero, so the division never affects the result
    result = jnp.where(norm == 0.0, v, jnp.asarray(v) / norm)
    return result


if __name__ == "__main__":
    N = 1000000  # a large array
    v = np.random.randn(N)

    # JIT compilation here
    norm_vector_jax_jit = jit(norm_vector_jax)

    print("Original:", norm_vector_original(v))
    print("No-JIT:", norm_vector_jax(v))
    print("JIT:", norm_vector_jax_jit(v))

    def original_func(): return norm_vector_original(v)
    def jax_func(): return norm_vector_jax(v)
    def jax_jit_func(): return norm_vector_jax_jit(v)

    num_executions = 100
    t_original = timeit.timeit(
        original_func, number=num_executions) / num_executions
    t_jax = timeit.timeit(jax_func, number=num_executions) / num_executions
    t_jax_jit = timeit.timeit(
        jax_jit_func, number=num_executions) / num_executions

    print("\nBenchmark:")
    print("Original:", t_original)
    print("No-JIT", t_jax)
    print("JIT", t_jax_jit)

The JIT version is about 2x as fast as NumPy in my small benchmark:

Original: [-0.0013723   0.001101   -0.00120956 ...  0.00043465  0.00133065
 -0.00048264]
No-JIT: [-0.0013723   0.001101   -0.00120956 ...  0.00043465  0.00133065
 -0.00048264]
JIT: [-0.0013723   0.001101   -0.00120956 ...  0.00043465  0.00133065
 -0.00048264]

Benchmark:
Original: 0.0011862749996362253
No-JIT: 0.0016969608300132677
JIT: 0.0006298654095735401
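As a starting point for finding the slowest functions, a quick profiling run with the standard library's cProfile could look like this; the profiled pytransform3d call is only an example target:

import cProfile
import pstats

import numpy as np

from pytransform3d.rotations import matrix_from_quaternion

# example workload: convert many random unit quaternions
quaternions = np.random.randn(10000, 4)
quaternions /= np.linalg.norm(quaternions, axis=1, keepdims=True)

profiler = cProfile.Profile()
profiler.enable()
for q in quaternions:
    matrix_from_quaternion(q)
profiler.disable()

# sort by cumulative time to see where most time is spent
pstats.Stats(profiler).sort_stats("cumulative").print_stats(10)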

@AlexanderFabisch (Member, Author)

Rewriting everything to JAX is a large effort with little benefit, imho.

You are right.

It will be hard to implement more complex functionality with JAX support, though. In that case, every operation involved would have to be JAX compatible (e.g., ScLERP, which uses a lot of basic vectorized functions).

I also tried to convert some of the more complex functions. The main problem is in-place assignments like

A[..., :3, :3] = ...

Instead of doing this, I'd recommend computing the parts and concatenating them afterwards, as sketched below. I'm not sure how well that works with conditional code, though.
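A minimal sketch of both options for a homogeneous transformation matrix; R and t are illustrative placeholders for a rotation matrix and a translation vector:

import jax.numpy as jnp

R = jnp.eye(3)                  # rotation part
t = jnp.array([1.0, 2.0, 3.0])  # translation part

# Option 1: JAX's functional index update instead of in-place assignment
A = jnp.eye(4)
A = A.at[:3, :3].set(R)
A = A.at[:3, 3].set(t)

# Option 2: compute the parts and assemble them once
bottom = jnp.array([[0.0, 0.0, 0.0, 1.0]])
A = jnp.concatenate(
    [jnp.concatenate([R, t[:, None]], axis=1), bottom], axis=0)

Both variants work under jit; the .at[...].set(...) form is usually the more direct translation of existing NumPy code.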
