Official code for the paper *Provable Compositional Generalization for Object-Centric Learning* (OpenReview, arXiv).
We formalize compositional generalization as an identifiability problem for a latent variable model where objects are represented by latent slots. Here, compositional generalization requires two things:
- Identifying the ground-truth latent slots in-distribution (blue).
- Generalizing this behavior to unseen, out-of-distribution slot combinations (grey).
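In symbols (a simplified sketch of the setting, not the paper's exact statement): images are generated by a decoder $f$ from $K$ slot latents, and the training distribution covers only part of the possible slot combinations,

```math
x = f(z_1, \dots, z_K), \qquad z \in \mathcal{Z}_\text{train} \subsetneq \mathcal{Z}_1 \times \dots \times \mathcal{Z}_K .
```

In-distribution identifiability then means an inference model $g$ recovers the slots, $g(f(z)) = z$ up to permutation and slot-wise reparametrization, for $z \in \mathcal{Z}_\text{train}$; compositional generalization means the same identity holds on the full product $\mathcal{Z}_1 \times \dots \times \mathcal{Z}_K$.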
On the empirical side, the main contribution is the compositional consistency loss outlined below:
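As a minimal sketch (illustrative only: `encoder`, `decoder`, and the plain MSE without slot matching are simplifying assumptions, see `main.py` for the actual implementation), the loss decodes novel slot combinations obtained by recombining slots across the batch and penalizes the encoder for failing to recover them:

```python
import torch

def compositional_consistency_loss(encoder, decoder, slots):
    """Sketch: make the encoder invert the decoder on *novel* slot
    compositions, not only on training images.

    slots: (batch, num_slots, latent_dim), inferred from real images.
    """
    b, k, _ = slots.shape
    # Permute each slot position independently across the batch,
    # yielding (most likely) unseen slot combinations.
    mixed = torch.stack(
        [slots[torch.randperm(b), i] for i in range(k)], dim=1
    )
    generated = decoder(mixed)        # images of the novel compositions
    re_inferred = encoder(generated)  # slots inferred back from them
    # Penalize the mismatch (the actual loss additionally matches
    # slots up to permutation).
    return ((re_inferred - mixed) ** 2).mean()
```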
This code was tested with Python 3.10.
You can start by cloning the repository:
```bash
git clone git@github.com:brendel-group/objects-compositional-generalization.git
cd objects-compositional-generalization
```
Then, set up your environment by choosing one of the following methods:
Option 1: Installing Dependencies Directly
```bash
pip install -r requirements.txt
```
Alternatively, you can use Docker:
Option 2: Building a Docker Image
Build and run a Docker container using the provided Dockerfile:
```bash
docker build -t object_centric_ood .
docker-compose up
```
🔗 To understand what the data looks like and to play with the data generation, see the `notebooks/0. Sprite-World Dataset Example.ipynb` notebook.
🔗 For the actual data generation, see the `notebooks/1. Data Generation.ipynb` notebook. The folder in which the dataset is saved at this step is later used for training and evaluation.
To train the model, run the following command:

```bash
python main.py --dataset_path "/path/from/previous/step" --model_name "SlotAttention" --num_slots 2 --epochs 400 --use_consistency_loss True
```
For complete details on the parameters, please refer to the `main.py` file.
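Assuming the script exposes a standard argparse-style CLI (an assumption suggested by the flags above), you can list all available options with:

```bash
python main.py --help
```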
You can find some example commands for training below:
Different Training Setups
Training SlotAttention
Training vanilla SlotAttention with 2 slots:

```bash
python main.py --dataset_path "/path/from/previous/step" --model_name "SlotAttention" --num_slots 2 --use_consistency_loss False
```
Training vanilla SlotAttention with 2 slots and consistency loss:

```bash
python main.py --dataset_path "/path/from/previous/step" --model_name "SlotAttention" --num_slots 2 --use_consistency_loss True --consistency_ignite_epoch 150
```
Training SlotAttention with 2 slots, fixed SoftMax and sampling:

```bash
python main.py --dataset_path "/path/from/previous/step" --model_name "SlotAttention" --num_slots 2 --use_consistency_loss True --consistency_ignite_epoch 150 --softmax False --sampling False
```
Training AE Model
Training vanilla autoencoder with 2 slots:

```bash
python main.py --dataset_path "/path/from/previous/step" --model_name "SlotMLPAdditive" --epochs 300 --num_slots 2 --n_slot_latents 6 --use_consistency_loss False
```
Training vanilla autoencoder with 2 slots and consistency loss:

```bash
python main.py --dataset_path "/path/from/previous/step" --model_name "SlotMLPAdditive" --epochs 300 --num_slots 2 --n_slot_latents 6 --use_consistency_loss True --consistency_ignite_epoch 100
```
Evaluation is done with the `evaluation.py` script and closely follows the procedure and metrics used in the training script. The main difference is the computation of the compositional contrast (note: it can cause OOM issues and is therefore computed only for the AE model).
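For intuition, here is a rough sketch of the compositional contrast (a hypothetical helper, simplified from the metric it is based on; the authoritative version is in `evaluation.py`). It sums, over all output pixels, the pairwise products of per-slot Jacobian norms, so it is zero exactly when every pixel depends on at most one slot; materializing the full decoder Jacobian is what makes it memory-hungry:

```python
import torch

def compositional_contrast(decoder, slots):
    """slots: (num_slots, latent_dim); decoder: slots -> flat image (num_pixels,).

    Returns 0 iff each pixel depends on at most one slot.
    """
    # Full Jacobian of the decoded image w.r.t. all slots:
    # shape (num_pixels, num_slots, latent_dim) -- the source of OOM issues.
    jac = torch.autograd.functional.jacobian(decoder, slots)
    norms = jac.norm(dim=-1)  # per-pixel gradient norm for each slot
    num_slots = norms.shape[1]
    contrast = norms.new_zeros(())
    for a in range(num_slots):
        for b in range(a + 1, num_slots):
            contrast = contrast + (norms[:, a] * norms[:, b]).sum()
    return contrast
```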
Here is an example command for evaluation:

```bash
python src/evaluation.py --dataset_path "/path/from/previous/step" --model_path "checkpoints/SlotMLPAdditive.pt" --model_name "SlotMLPAdditive" --n_slot_latents 6
```
Misc
🔗 `notebooks/2. Decoder Optimality.ipynb` shows how the experiment described in Appendix B.2 is conducted. Note that the paths to model checkpoints and the exact data splits are omitted, so the notebook serves an illustrative purpose.
If you make use of this code in your own work, please cite our paper:
```bibtex
@inproceedings{wiedemer2024provable,
  title={Provable Compositional Generalization for Object-Centric Learning},
  author={Thadd{\"a}us Wiedemer and Jack Brady and Alexander Panfilov and Attila Juhos and Matthias Bethge and Wieland Brendel},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024},
  url={https://openreview.net/forum?id=7VPTUWkiDQ}
}
```