GPU support requires installed, working CUDA drivers. To enable CUDA in JAX, install:
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
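To confirm that the CUDA-enabled install worked, you can ask JAX which devices it sees. The helper below is a minimal sketch. The function name `jax_gpu_available` is hypothetical. It assumes that `jax.devices("gpu")` raises `RuntimeError` when no GPU backend can be initialized, which is how JAX reports a missing or broken CUDA setup.

```python
import importlib


def jax_gpu_available() -> bool:
    """Return True if JAX imports and reports at least one GPU device (hypothetical helper)."""
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        # JAX (or jaxlib) is not installed at all
        return False
    try:
        # Raises RuntimeError when the GPU backend is unavailable, e.g. CPU-only jaxlib or no CUDA driver
        return len(jax.devices("gpu")) > 0
    except RuntimeError:
        return False


if __name__ == "__main__":
    print("JAX GPU available:", jax_gpu_available())
```

If this prints `False` on a machine with a GPU, the usual suspects are a CPU-only jaxlib or a CUDA driver that JAX cannot load.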
To enable tree visualization with Graphviz, first install the system packages:
sudo apt-get install libgraphviz-dev graphviz
then install the Python bindings:
pip install pygraphviz
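A quick sanity check for the Graphviz setup is to import pygraphviz and build a tiny graph. This is a sketch; the helper name `pygraphviz_works` is hypothetical. Importing pygraphviz fails with `ImportError` when its C extension cannot find `libgraphviz`, so that case is treated as "not working".

```python
import importlib


def pygraphviz_works() -> bool:
    """Return True if pygraphviz imports and can build a small directed graph (hypothetical helper)."""
    try:
        pgv = importlib.import_module("pygraphviz")
    except ImportError:
        # pygraphviz is missing, or its extension cannot load libgraphviz
        return False
    g = pgv.AGraph(directed=True)
    g.add_edge("root", "leaf")  # both nodes are created implicitly
    return g.number_of_nodes() == 2 and g.number_of_edges() == 1


if __name__ == "__main__":
    print("pygraphviz usable:", pygraphviz_works())
```

Rendering a graph to an image, for example with `g.draw("tree.png", prog="dot")`, additionally needs the Graphviz command-line tools from the `apt-get` step.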