Performant kernels for equivariant neural networks in Triton-lang
EquiTriton is a project that seeks to implement high-performance kernels for commonly used building blocks in equivariant neural networks, enabling compute-efficient training and inference. The advantage of Triton-lang is portability across GPU architectures: the kernels here have been tested against GPUs from multiple vendors, including Nvidia A100/H100 and the Intel® Data Center GPU Max Series 1550.
Our current scope includes components such as spherical harmonics (including derivatives), up to tenth order (l = 10).
For users, run `pip install git+https://github.com/IntelLabs/EquiTriton`. For those using Intel XPUs, we recommend reading the section on Intel XPU usage first, and setting up an environment with PyTorch, IPEX, and Triton for XPU before installing EquiTriton.
For developers/contributors, please clone this repository and install it in editable mode:
```bash
git clone https://github.com/IntelLabs/EquiTriton
cd EquiTriton
pip install -e './[dev]'
```
...which will include development dependencies such as `pre-commit` (used for linting and formatting) and `jupyter`, used for symbolic differentiation during kernel development.
Finally, we provide `Dockerfile`s for users who prefer containers.
As a drop-in replacement for `e3nn` spherical harmonics, simply include the following in your code:

```python
from equitriton import patch
```
This will dynamically replace the `e3nn` spherical harmonics implementation with the EquiTriton kernels.
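As a minimal sketch of what this looks like end to end (the `o3.spherical_harmonics` call is the standard `e3nn` entry point; after the patch it should dispatch to the Triton kernels, and a GPU device is assumed since Triton kernels operate on device tensors):

```python
import torch
from equitriton import patch  # noqa: F401 - applying the patch is a side effect of import
from e3nn import o3

coords = torch.rand(100, 3, device="cuda")  # or device="xpu" on Intel GPUs

# A standard e3nn call; with the patch imported above, this should now
# route through the EquiTriton kernels instead of the stock implementation.
out = o3.spherical_harmonics([0, 1, 2], coords, normalize=True)
print(out.shape)  # torch.Size([100, 9]) -> 1 + 3 + 5 components
```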
There are two important things to consider before replacing:
- Numerically, there are small differences between implementations, primarily in the backward pass. Because terms in the gradients are implemented as literals, they can be more susceptible to rounding errors at lower precision. In most (not all!) instances the implementations are numerically equivalent at `torch.float32`, and essentially always differ at `torch.float16`. At double precision (`torch.float64`) this does not appear to be an issue, which makes the kernels well suited for simulation loops. Be aware that if they are used for training, the optimization trajectory may not be exactly the same; we have not tested for divergence and encourage experimentation.
- Triton kernels are compiled just-in-time and cached every time a new input tensor shape is encountered. In `equitriton.sph_harm.SphericalHarmonics`, the `pad_tensor` argument (default `True`) is used to try to maximize cache re-use by padding nodes and masking in the forward pass; see the sketch after this list. The script `scripts/dynamic_shapes.py` will let you test the performance over a range of shapes; we encourage you to test it before performing full-scale training/inference.
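As a hedged sketch of how padding fits in (the exact constructor signature of `SphericalHarmonics` is an assumption here; check the class docstring for the real arguments):

```python
import torch
from equitriton.sph_harm import SphericalHarmonics

# Hypothetical arguments: the order to evaluate plus the pad_tensor flag
# described above; the true signature may differ.
sph_harm = SphericalHarmonics(2, pad_tensor=True)

# Graphs of different sizes would normally trigger fresh JIT compilations;
# padding to a common shape lets the kernel cache be re-used instead.
for num_nodes in (97, 128, 211):
    coords = torch.rand(num_nodes, 3, device="cuda")  # or device="xpu"
    out = sph_harm(coords)
```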
We recently published a paper at the AI4Mat workshop at NeurIPS 2024. As part of that work, we went back into `sympy` to refactor the spherical harmonics up to tenth order (l = 10).
Functionally, these kernels are intended to behave in the same way as their original implementations, i.e. they still provide equivariant properties when used to map Cartesian point clouds. However, because of the aggressive refactoring and heavy use of hard-coded literals, they may differ numerically from even the initial EquiTriton kernels, particularly at higher orders.
**Important:** For the above reason, while the kernels can be drop-in replacements, we do not recommend using them with already trained models, at least not without some testing on the user's part, as the results may differ. We have also not yet attempted to use these kernels as part of simulation-based workflows (e.g. molecular dynamics); however, our training experiments do show that training converges.
To use the new set of decoupled kernels, the main `torch.autograd` binding is through `equitriton.sph_harm.direct.TritonSphericalHarmonic`:
```python
import torch
from equitriton.sph_harm.direct import TritonSphericalHarmonic

coords = torch.rand(100, 3)
sph_harm = TritonSphericalHarmonic.apply(
    l_values=[0, 1, 2, 6, 10],
    coords=coords
)
```
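Because `TritonSphericalHarmonic` is a `torch.autograd` binding, gradients flow through `apply` as usual. As a hedged sketch of the kind of testing suggested above (normalization conventions between `e3nn` and these kernels may differ, so the comparison and tolerance below are assumptions to adapt):

```python
import torch
from e3nn import o3
from equitriton.sph_harm.direct import TritonSphericalHarmonic

l_values = [0, 1, 2]
coords = torch.rand(100, 3, device="cuda", requires_grad=True)  # or "xpu"

triton_out = TritonSphericalHarmonic.apply(l_values=l_values, coords=coords)
e3nn_out = o3.spherical_harmonics(l_values, coords, normalize=True)

# Forward agreement; expect the tolerance to depend on dtype per the notes above.
print(torch.allclose(triton_out, e3nn_out, atol=1e-5))

# The backward pass runs through the same autograd binding.
triton_out.sum().backward()
print(coords.grad.shape)  # torch.Size([100, 3])
```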
The performance improvements are expected to come from (1) decoupling each spherical harmonic order, and (2) pre-allocating an output tensor to avoid `torch.cat`, which computes each order and then copies the results. See the "Direct spherical harmonics evaluation" notebook in the notebooks folder for the derivation.
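To illustrate point (2) with a generic sketch (this is not the kernel code itself): `torch.cat` materializes every per-order result and then copies each one into a fresh tensor, whereas writing into slices of a pre-allocated output skips that extra copy.

```python
import torch

num_nodes = 100
l_values = [0, 1, 2, 6, 10]
sizes = [2 * l + 1 for l in l_values]  # each order l has 2l + 1 components

# torch.cat style: every per-order tensor is computed, then copied again.
per_order = [torch.rand(num_nodes, n) for n in sizes]
out_cat = torch.cat(per_order, dim=-1)

# Pre-allocated style: each order is written straight into its output slice.
out = torch.empty(num_nodes, sum(sizes))
offset = 0
for n in sizes:
    out[:, offset : offset + n] = torch.rand(num_nodes, n)
    offset += n
```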
Development on Intel XPUs such as the Data Center GPU Max Series 1550 requires a number of manually installed components on bare metal. The core dependency to consider is the Intel XPU backend for Triton, which dictates the versions of oneAPI, PyTorch, and Intel Extension for PyTorch (IPEX) to install. At the time of release, EquiTriton was developed on the following:
- oneAPI 2024.0
- PyTorch 2.1.0
- IPEX 2.1.10+xpu
- Intel XPU backend for Triton 2.1.0
Due to the way that wheels are distributed, please install PyTorch and IPEX per `intel-requirements.txt`. Alternatively, use the provided Docker image for development.
To check that the environment can see your XPU:

```python
>>> import intel_extension_for_pytorch
>>> import torch
>>> torch.xpu.device_count()  # should be greater than zero
```
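As a further hedged sanity check, running a small op on the device confirms it is actually usable (the `"xpu"` device string is the one IPEX registers):

```python
import intel_extension_for_pytorch  # noqa: F401 - registers the "xpu" device
import torch

x = torch.rand(8, 3, device="xpu")
print((x @ x.T).sum().item())  # any small computation exercising the device
```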
Some other useful diagnostic tools:
- `xpu-smi` (might not be installed), as the name suggests, is the equivalent of `nvidia-smi`, but with a bit more functionality based on our architecture.
- `sycl-ls` is provided by the `dpcpp` runtime, and lists out all devices that are OpenCL and SYCL capable. Notably, this can be used to quickly check how many GPUs are available.
- pti-gpu provides a set of tools that you can compile for profiling. Notably, `unitrace` and `oneprof` allow you to do low-level profiling of the device.
We welcome contributions from the open-source community! If you have any questions or suggestions, feel free to create an issue in our repository. We will be happy to work with you to make this project even better.
The code and documentation in this repository are licensed under the Apache 2.0 license. By contributing to this project, you agree that your contributions will be licensed under this license.
If you find this repo useful, please consider citing the respective papers.
For the original EquiTriton implementation, please use/read the following citation:
```bibtex
@inproceedings{lee2024scaling,
    title={Scaling Computational Performance of Spherical Harmonics Kernels with Triton},
    author={Kin Long Kelvin Lee and Mikhail Galkin and Santiago Miret},
    booktitle={AI for Accelerated Materials Design - Vienna 2024},
    year={2024},
    url={https://openreview.net/forum?id=ftK00FO5wq}
}
```
For the refactored spherical harmonics up to tenth order (l = 10), please use/read the following citation:
```bibtex
@inproceedings{lee2024deconstructing,
    title={Deconstructing equivariant representations in molecular systems},
    author={Kin Long Kelvin Lee and Mikhail Galkin and Santiago Miret},
    booktitle={AI for Accelerated Materials Design - NeurIPS 2024},
    year={2024},
    url={https://openreview.net/forum?id=pshyLoyzRn}
}
```