Implementation of the 3D Cartesian sinc DVR basis from the paper by Jones et al.
To test different device geometries locally, the XLA_FLAGS environment variable from JAX can be used.
For example, to emulate 9 devices, run:
export XLA_FLAGS='--xla_force_host_platform_device_count=9'
or, alternatively, set it at the top of the script, before jax is imported:
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=9'
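A minimal sketch, assuming a recent JAX with the jax.sharding API, of emulating 9 CPU devices and arranging them in a few 2D device geometries. The mesh shapes, the axis names "x" and "y", and the 9 x 9 test array are hypothetical and not taken from sinc-dvr; setting JAX_PLATFORMS to "cpu" is an addition that keeps JAX on the CPU backend, where the emulated devices live.

```python
import os

# Both variables must be set before JAX initializes its backends,
# i.e. before `import jax`.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=9"
os.environ["JAX_PLATFORMS"] = "cpu"  # assumption: force the CPU backend

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
print(len(devices))  # number of emulated host devices

x = jnp.arange(81.0).reshape(9, 9)

# Hypothetical geometries: lay the 9 devices out as 9x1, 3x3 and 1x9 grids.
shard_shapes = {}
for geometry in [(9, 1), (3, 3), (1, 9)]:
    mesh = Mesh(np.array(devices).reshape(geometry), axis_names=("x", "y"))
    # Split array axis 0 over mesh axis "x" and array axis 1 over "y".
    x_sharded = jax.device_put(x, NamedSharding(mesh, P("x", "y")))
    # 9 divides both array axes evenly, so every shard has the same shape.
    shard_shapes[geometry] = x_sharded.addressable_shards[0].data.shape

print(shard_shapes)
```

Changing only the reshape target is enough to try a different layout of the same emulated devices.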
Since JAX requires the target device (e.g. CPU or GPU) to be chosen at installation time, this choice is also part of the installation of sinc-dvr.
For a local installation using jaxlib on the CPU, run:
pip install -e ".[cpu]"
Instead of cpu, the same extras as for jax can be specified.
For example, to use CUDA 12 with the CUDA binaries installed through pip, run:
pip install -e ".[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Here the link after -f is needed by pip to find the jaxlib wheels built for the correct CUDA version.
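A quick check, not part of sinc-dvr itself, of which backend the installed jaxlib actually provides. It only uses the standard jax.default_backend and jax.devices functions.

```python
import jax

# Backend JAX selects by default, e.g. "cpu" for the [cpu] extra or "gpu"
# when a CUDA-enabled jaxlib was installed and a GPU is visible.
backend = jax.default_backend()
print(backend, jax.devices())
```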
To run the tests call:
python -m unittest
Note that the tests emulate several devices using the XLA_FLAGS environment variable from above.
This does not seem to be supported on the GPU, so the tests should be run on the CPU.
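To illustrate the pattern, a hypothetical test module (not one of the sinc-dvr tests) that sets XLA_FLAGS at the very top, before jax is imported, so that the flag takes effect, and pins JAX to the CPU as recommended above. The class and test names are made up for this sketch; saved as a test_*.py file, it would be picked up by python -m unittest.

```python
import os

# Must run before `import jax`, otherwise the flag is ignored.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=9"
os.environ["JAX_PLATFORMS"] = "cpu"  # device emulation works on the CPU backend

import unittest

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


class TestEmulatedDevices(unittest.TestCase):  # hypothetical test case
    def test_device_count(self):
        self.assertEqual(jax.device_count(), 9)

    def test_sharded_sum_matches_unsharded(self):
        # 1D mesh over all 9 emulated devices; each holds 2 of 18 elements.
        mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
        x = jnp.arange(18.0)
        xs = jax.device_put(x, NamedSharding(mesh, P("x")))
        self.assertAlmostEqual(float(jnp.sum(xs)), float(jnp.sum(x)))
```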