Official repository for the NeurIPS'24 paper "Scale Equivariant Graph Metanetworks" by Ioannis Kalogeropoulos*, Giorgos Bouritsas* and Yannis Panagakis.
[arXiv]
This paper pertains to an emerging machine learning paradigm: learning higher-order functions, i.e. functions whose inputs are functions themselves, particularly when these inputs are Neural Networks (NNs). With the growing interest in architectures that process NNs, a recurring design principle has permeated the field: adhering to the permutation symmetries arising from the connectionist structure of NNs. However, are these the sole symmetries present in NN parameterizations? Zooming into most practical activation functions (e.g. sine, ReLU, tanh) answers this question negatively and gives rise to intriguing new symmetries, which we collectively refer to as scaling symmetries, that is, non-zero scalar multiplications and divisions of weights and biases. In this work, we propose Scale Equivariant Graph MetaNetworks - ScaleGMNs, a framework that adapts the Graph Metanetwork (message-passing) paradigm by incorporating scaling symmetries and thus rendering neuron and edge representations equivariant to valid scalings. We introduce novel building blocks, of independent technical interest, that allow for equivariance or invariance with respect to individual scalar multipliers or their product and use them in all components of ScaleGMN. Furthermore, we prove that, under certain expressivity conditions, ScaleGMN can simulate the forward and backward pass of any input feedforward neural network. Experimental results demonstrate that our method advances the state-of-the-art performance for several datasets and activation functions, highlighting the power of scaling symmetries as an inductive bias for NN processing.
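To make the scaling symmetries concrete, below is a minimal NumPy sketch (illustrative only, not part of this repository): for a ReLU network, multiplying a hidden neuron's incoming weights and bias by any q > 0 and its outgoing weights by 1/q leaves the network function unchanged; for odd activations such as sine or tanh, q = -1 is also a valid symmetry.

import numpy as np

rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(16, 8)), rng.normal(size=16)   # hidden layer
W2, b2 = rng.normal(size=(4, 16)), rng.normal(size=4)    # output layer
x = rng.normal(size=(32, 8))                              # a batch of inputs

def relu_mlp(W1, b1, W2, b2, x):
    h = np.maximum(x @ W1.T + b1, 0.0)                    # ReLU hidden layer
    return h @ W2.T + b2

# Rescale hidden neuron i: incoming row and bias by q > 0, outgoing column by 1/q.
q, i = 3.7, 5
W1s, b1s, W2s = W1.copy(), b1.copy(), W2.copy()
W1s[i], b1s[i] = q * W1s[i], q * b1s[i]
W2s[:, i] = W2s[:, i] / q

# Both parameterizations realize exactly the same function.
assert np.allclose(relu_mlp(W1, b1, W2, b2, x), relu_mlp(W1s, b1s, W2s, b2, x))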
To create a clean virtual environment and install the necessary dependencies, execute:
git clone git@github.com:jkalogero/scalegmn.git
cd scalegmn/
conda env create -n scalegmn --file environment.yml
conda activate scalegmn
First, create the data/ directory in the root of the repository:
mkdir data
Alternatively, you can specify a different directory for the data by changing the corresponding fields in the config file.
For the INR datasets, we use the data provided by DWS and NFN. The datasets can be downloaded from the following links:
- MNIST-INRs (Navon et al., 2023)
- FMNIST-INRs (Navon et al., 2023)
- CIFAR10-INRs (Zhou et al., 2023)
Download the datasets and extract them in the directory data/. For example, you can run the following to download and extract the MNIST-INR dataset and generate the splits:
DATA_DIR=./data
wget "https://www.dropbox.com/sh/56pakaxe58z29mq/AABrctdu2U65jGYr2WQRzmMna/mnist-inrs.zip?dl=0" -O "$DATA_DIR/mnist-inrs.zip"
unzip -q "$DATA_DIR/mnist-inrs.zip" -d "$DATA_DIR"
rm "$DATA_DIR/mnist-inrs.zip" # remove the zip file
# generate the splits
python src/utils/generate_data_splits.py --data_path $DATA_DIR/mnist-inrs --save_path $DATA_DIR/mnist-inrs
Generating the splits is necessary only for the MNIST-INR dataset.
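As a rough illustration of what the split-generation step produces, the sketch below builds a random train/val/test partition over the INR checkpoints. The file extension, ratios and output format here are hypothetical; the actual logic lives in src/utils/generate_data_splits.py.

import json, random
from pathlib import Path

data_path = Path("data/mnist-inrs")
ckpts = sorted(str(p) for p in data_path.rglob("*.pth"))   # hypothetical checkpoint extension
random.Random(42).shuffle(ckpts)

n = len(ckpts)
splits = {
    "train": ckpts[: int(0.8 * n)],                         # hypothetical 80/10/10 split
    "val": ckpts[int(0.8 * n): int(0.9 * n)],
    "test": ckpts[int(0.9 * n):],
}
with open(data_path / "splits.json", "w") as f:              # hypothetical output file
    json.dump(splits, f)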
For the INR datasets, we preprocess each datapoint to canonicalize the phase symmetry (see Algorithm 1 in the appendix; a conceptual sketch follows the commands below). To run the phase canonicalization script, execute:
python src/phase_canonicalization/canonicalization.py --conf src/phase_canonicalization/<dataset>.yml
where <dataset> can be one of mnist, fmnist, cifar.
To apply the canonicalization to the augmented CIFAR10-INR dataset, also run:
python src/phase_canonicalization/canonicalization.py --conf src/phase_canonicalization/cifar.yml --extra_aug 20
The above script will store the canonicalized dataset in a new directory data/<dataset>_canon/. The training scripts will automatically use the canonicalized dataset, if it exists. To use the dataset specified in the config file (and not search for data/<dataset>_canon/), set the data.switch_to_canon field of the config to False, or simply use the CLI argument --data.switch_to_canon False.
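For intuition, the sketch below illustrates the kind of symmetry this preprocessing removes. It is a conceptual example assuming an odd activation such as sine, not the repository's Algorithm 1: flipping the sign of a hidden neuron's incoming weights and bias together with its outgoing weights leaves the network unchanged, so a canonicalization can fix one representative per sign orbit, e.g. by making every hidden bias nonnegative.

import numpy as np

def sine_mlp(W1, b1, W2, b2, x):
    return np.sin(x @ W1.T + b1) @ W2.T + b2

def canonicalize_signs(W1, b1, W2):
    # Illustrative rule: flip every hidden neuron whose bias is negative.
    s = np.where(b1 < 0, -1.0, 1.0)                 # one sign per hidden neuron
    return W1 * s[:, None], b1 * s, W2 * s[None, :]

rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(16, 2)), rng.normal(size=16)
W2, b2 = rng.normal(size=(1, 16)), rng.normal(size=1)
x = rng.uniform(-1, 1, size=(64, 2))                # e.g. 2D pixel coordinates of an INR

W1c, b1c, W2c = canonicalize_signs(W1, b1, W2)
assert np.allclose(sine_mlp(W1, b1, W2, b2, x), sine_mlp(W1c, b1c, W2c, b2, x))
assert (b1c >= 0).all()                             # canonical representative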
We follow the experiments from NFN and use the datasets provided by Unterthiner et al., 2020. The datasets can be downloaded from the following links:
Similarly, extract the datasets in the directory data/ and execute:
For the CIFAR10 dataset:
tar -xvf cifar10.tar.xz
# download cifar10 splits
wget https://github.com/AllanYangZhou/nfn/raw/refs/heads/main/experiments/predict_gen_data_splits/cifar10_split.csv -O data/cifar10/cifar10_split.csv
For the SVHN dataset:
tar -xvf svhn_cropped.tar.xz
# download svhn splits
wget https://github.com/AllanYangZhou/nfn/raw/refs/heads/main/experiments/predict_gen_data_splits/svhn_split.csv -O data/svhn_cropped/svhn_split.csv
For every experiment, we provide the corresponding configuration file in the configs/ directory.
Each config contains the selected hyperparameters for the experiment, as well as the paths to the dataset.
To enable wandb logging, use the CLI argument --wandb True. For more useful CLI arguments, check the src/utils/setup_arg_parser.py file.
Note: To employ a GMN accounting only for the permutation symmetries, simply set --scalegmn_args.symmetry=permutation.
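For instance, assuming the dot-separated overrides accepted by src/utils/setup_arg_parser.py, a run that logs to wandb and uses only permutation symmetries might look like:

python inr_classification.py --conf configs/fmnist_cls/scalegmn.yml --wandb True --scalegmn_args.symmetry=permutation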
To train and evaluate ScaleGMN on the INR classification task, select any config file under configs/mnist_cls, configs/fmnist_cls or configs/cifar_inr_cls. For example, to train ScaleGMN on the FMNIST-INR dataset, execute the following:
python inr_classification.py --conf configs/fmnist_cls/scalegmn.yml
To train and evaluate ScaleGMN on the INR editing task, use the configs under the configs/mnist_editing directory and execute:
python inr_editing.py --conf configs/mnist_editing/scalegmn_bidir.yml
To train and evaluate ScaleGMN on the generalization prediction task, select any config file under configs/cifar10 or configs/svhn. For example, to train ScaleGMN on the CIFAR10 dataset with heterogeneous activation functions, execute the following:
python predicting_generalization.py --conf configs/cifar10/scalegmn_hetero.yml
@article{kalogeropoulos2024scale,
title={Scale Equivariant Graph Metanetworks},
author={Kalogeropoulos, Ioannis and Bouritsas, Giorgos and Panagakis, Yannis},
journal={Advances in Neural Information Processing Systems},
year={2024}
}