Geometric Algebra Transformer Made Easy

What is EzGATr?

EzGATr (Easy Geometric Algebra Transformer) intends to be a simple-to-use and lightweight Python library for building 3D Geometric Algebra Transformers (GATr). It is a collection of operators, modules, and utilities built on top of PyTorch. In addition, EzGATr seeks to bridge the gap between the mathematical formulations and the corresponding implementations through extensive documentation and explanations, to facilitate learning and potential future optimizations.

Installation

EzGATr is currently under active development. To install it, clone the repository and install it with pip. Use the -e flag to install it in editable mode so that local changes take effect immediately.

$ git clone https://github.com/Guest400123064/ezgatr.git
$ cd ezgatr
$ pip install -e .

Alternatively, one can install it directly from PyPI.

$ pip install ezgatr

Usage

The quick start example shown below demonstrates how to use EzGATr to build an equivariant network with one input/output channel (e.g., a point cloud) and four hidden channels. The mock input data contains a batch of eight samples, each with 256 3D objects embedded as multi-vectors.

import torch
import torch.nn as nn

from ezgatr.nn import EquiLinear
from ezgatr.nn.functional import scaler_gated_gelu


class SimpleNet(nn.Module):
    """Two equivariant linear layers with a scalar-gated GELU in between."""

    def __init__(self):
        super().__init__()
        self.fc1 = EquiLinear(1, 4)  # 1 input channel -> 4 hidden channels
        self.fc2 = EquiLinear(4, 1)  # 4 hidden channels -> 1 output channel

    def forward(self, x):
        x = self.fc1(x)
        x = scaler_gated_gelu(x)
        x = self.fc2(x)
        return x


net = SimpleNet()
in_ = torch.randn(8, 256, 1, 16)  # (batch, objects, channels, multi-vector components)
out = net(in_)
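As a side note, the trailing dimension of 16 in the mock input is the number of basis blades of the projective geometric algebra G(3, 0, 1) that GATr operates on: with four generators, every subset of them is a basis blade, giving 2^4 = 16 components per multi-vector. A quick enumeration in plain Python confirms the count:

```python
from itertools import combinations

# The projective geometric algebra G(3, 0, 1) has four generators
# (e0, e1, e2, e3); its basis blades are all products of subsets of them.
generators = ["e0", "e1", "e2", "e3"]
blades = [
    "".join(combo) or "1"  # the empty product is the scalar part
    for k in range(len(generators) + 1)
    for combo in combinations(generators, k)
]
print(len(blades))  # 2**4 = 16 components per multi-vector
print(blades)
```

This is why every tensor of multi-vectors in EzGATr carries a trailing dimension of size 16.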

Refer to this example for building a full-fledged GATr model with EzGATr, involving equivariant geometric attention and a geometric MLP.

API References

The complete API references for EzGATr can be found here. In short, the package is organized as follows:

  • ezgatr.nn: Contains the core modules and layers for building GATr models. It is organized similarly to PyTorch's torch.nn package, where the functional submodule contains lower-level operators and transformations.
  • ezgatr.interfaces: Contains the utility functions that help encode and decode 3D objects to and from multi-vectors.
  • ezgatr.nets: Contains off-the-shelf networks built with EzGATr building blocks. They can also serve as references for building custom networks with EzGATr.
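To illustrate what the encoders in ezgatr.interfaces do, the sketch below embeds a batch of 3D points into 16-component multi-vectors by hand. The helper name encode_points and the specific blade indices used here are assumptions for illustration only; the library's actual encoding conventions and function names are documented in the API reference.

```python
import torch


def encode_points(points: torch.Tensor) -> torch.Tensor:
    """Embed ``(..., 3)`` points as ``(..., 16)`` multi-vectors.

    Hypothetical sketch: the blade indices below are illustrative
    assumptions, not necessarily the library's actual layout.
    """
    mv = torch.zeros(*points.shape[:-1], 16)
    mv[..., 14] = 1.0        # homogeneous trivector component (assumed index)
    mv[..., 11:14] = points  # x, y, z coordinates (assumed indices)
    return mv


batch = encode_points(torch.rand(8, 256, 3))
print(batch.shape)  # torch.Size([8, 256, 16])
```

A matching decoder would read those components back out of a multi-vector; consult ezgatr.interfaces for the encoders and decoders of planes, rotations, translations, and other 3D objects.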

Citation and Authors

If you find EzGATr useful in your research, please consider citing it using the following BibTeX entry:

@misc{ezgatr2024,
  author = {Wang, Yuxuan and Sun, Xiatao},
  title = {EzGATr: Geometric Algebra Transformer Made Easy},
  year = 2024,
  publisher = {Zenodo},
  version = {v0.1.0-alpha},
  doi = {10.5281/zenodo.13920438},
  url = {https://doi.org/10.5281/zenodo.13920438}
}

EzGATr builds on GATr (official GitHub repository) by Johann Brehmer, Pim de Haan, Sönke Behrends, and Taco Cohen, published at NeurIPS 2023. Please consider citing their work as well.

@inproceedings{NEURIPS2023_6f6dd92b,
  author = {Brehmer, Johann and de Haan, Pim and Behrends, S\"{o}nke and Cohen, Taco S},
  booktitle = {Advances in Neural Information Processing Systems},
  editor = {A. Oh and T. Naumann and A. Globerson and K. Saenko and M. Hardt and S. Levine},
  pages = {35472--35496},
  publisher = {Curran Associates, Inc.},
  title = {Geometric Algebra Transformer},
  url = {https://proceedings.neurips.cc/paper_files/paper/2023/file/6f6dd92b03ff9be7468a6104611c9187-Paper-Conference.pdf},
  volume = {36},
  year = {2023}
}

License

EzGATr is distributed under the terms of the MIT license.