
Official code for the paper "Provable Compositional Generalization for Object-Centric Learning" (ICLR 2024, oral)


Provable Compositional Generalization for Object-Centric Learning [ICLR 2024]

Official code for the paper Provable Compositional Generalization for Object-Centric Learning (openreview, arXiv).

We formalize compositional generalization as an identifiability problem for a latent variable model where objects are represented by latent slots. Here, compositional generalization requires two things:

  1. Identifying the ground-truth latent slots in-distribution (blue).
  2. Generalizing this behavior to unseen, out-of-distribution slot combinations (grey).

Overview

On the empirical side, the main contribution is the compositional consistency loss outlined below:

[Figure: Problem Setup]
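To make the idea more concrete, here is a minimal NumPy sketch of a compositional consistency loss as we read it from the paper. Inferred slots from different images are recombined into new slot combinations, decoded into images, and re-encoded, and the re-encoded slots are compared with the recombined ones. This is not the repository's implementation. The function names, the random per-position recombination, and the brute-force matching over slot orderings (used so that slot order does not matter) are illustrative assumptions.

```python
import itertools

import numpy as np


def recombine_slots(slots, rng):
    """Create novel slot combinations by drawing every slot position from a random image.

    Assumed scheme (illustrative only): slots has shape (batch, num_slots, slot_dim),
    and output sample b takes its slot k from image src[b, k].
    """
    batch, num_slots, _ = slots.shape
    src = rng.integers(0, batch, size=(batch, num_slots))
    return slots[src, np.arange(num_slots)[None, :]]


def matched_mse(pred, target):
    """Per-sample MSE minimized over slot orderings (brute force, fine for few slots)."""
    num_slots = pred.shape[1]
    per_sample = [
        min(
            np.mean((pred[b, list(perm)] - target[b]) ** 2)
            for perm in itertools.permutations(range(num_slots))
        )
        for b in range(pred.shape[0])
    ]
    return float(np.mean(per_sample))


def consistency_loss(encode, decode, slots, rng):
    """Hypothetical consistency loss: recombine -> decode -> re-encode -> compare."""
    z_new = recombine_slots(slots, rng)  # slot combinations not seen together
    x_new = decode(z_new)                # images rendered from those combinations
    z_hat = encode(x_new)                # slots inferred back from the rendered images
    return matched_mse(z_hat, z_new)


# Toy usage: a "decoder" that flattens slots and an "encoder" that exactly inverts it.
def toy_decode(z):
    return z.reshape(z.shape[0], -1)


def toy_encode(x):
    return x.reshape(x.shape[0], 2, 6)


rng = np.random.default_rng(0)
slots = rng.normal(size=(8, 2, 6))
print(consistency_loss(toy_encode, toy_decode, slots, rng))  # 0.0
```

Brute-force matching grows factorially with the number of slots; with more than a handful of slots, Hungarian matching (for example scipy.optimize.linear_sum_assignment) is the usual replacement.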

Environment Setup

This code was tested with Python 3.10.

You can start by cloning the repository:

git clone git@github.com:brendel-group/objects-compositional-generalization.git
cd objects-compositional-generalization

Then, set up your environment by choosing one of the following methods:

Option 1: Installing Dependencies Directly
pip install -r requirements.txt

Alternatively, you can use Docker:

Option 2: Building a Docker Image

Build and run a Docker container using the provided Dockerfile:

docker build -t object_centric_ood .
docker-compose up

Data Generation

🔗 To see what the data looks like and to experiment with data generation, refer to the notebooks/0. Sprite-World Dataset Example.ipynb notebook.

🔗 To generate the actual dataset, refer to the notebooks/1. Data Generation.ipynb notebook. The folder you save the dataset to in this step is the one you pass as --dataset_path for training and evaluation.

Training and Evaluation

Training

To train the model, run the following command:

python main.py --dataset_path "/path/from/previous/step" --model_name "SlotAttention" --num_slots 2 --epochs 400 --use_consistency_loss True

For complete details on the parameters, please refer to the main.py file.
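If you want to run several configurations in a row (for example, the same model with and without the consistency loss), a small Python launcher can assemble the argument lists. This is a hypothetical helper, not part of the repository; it only uses flags that appear in the commands in this README. Passing a list to subprocess.run avoids shell-quoting problems with dataset paths that contain spaces.

```python
import subprocess
import sys


def build_train_command(dataset_path, model_name, num_slots, use_consistency_loss,
                        extra_args=()):
    """Assemble a main.py invocation (hypothetical helper; flags taken from this README)."""
    cmd = [
        sys.executable, "main.py",
        "--dataset_path", dataset_path,
        "--model_name", model_name,
        "--num_slots", str(num_slots),
        # The README examples pass booleans as the literal strings "True"/"False".
        "--use_consistency_loss", str(bool(use_consistency_loss)),
    ]
    cmd.extend(extra_args)
    return cmd


def launch(cmd, dry_run=True):
    """Print the command; only execute it when dry_run is False."""
    print(" ".join(cmd))
    if not dry_run:
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    dataset = "/path/from/previous/step"
    for use_cons in (False, True):
        extra = ["--consistency_ignite_epoch", "150"] if use_cons else []
        launch(build_train_command(dataset, "SlotAttention", 2, use_cons, extra))
```

Set dry_run=False once the printed commands look right.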

You can find some example commands for training below:

Different Training Setups
  • Training SlotAttention

    Training vanilla SlotAttention with 2 slots:

    python main.py --dataset_path "/path/from/previous/step" --model_name "SlotAttention" --num_slots 2 --use_consistency_loss False

    Training vanilla SlotAttention with 2 slots and consistency loss:

    python main.py --dataset_path "/path/from/previous/step" --model_name "SlotAttention" --num_slots 2 --use_consistency_loss True --consistency_ignite_epoch 150

    Training SlotAttention with 2 slots and consistency loss, with SoftMax and sampling fixed (disabled via --softmax False --sampling False):

    python main.py --dataset_path "/path/from/previous/step" --model_name "SlotAttention" --num_slots 2 --use_consistency_loss True --consistency_ignite_epoch 150 --softmax False --sampling False
  • Training AE Model

    Training a vanilla autoencoder with 2 slots:

    python main.py --dataset_path "/path/from/previous/step" --model_name "SlotMLPAdditive" --epochs 300 --num_slots 2 --n_slot_latents 6 --use_consistency_loss False

    Training a vanilla autoencoder with 2 slots and consistency loss:

    python main.py --dataset_path "/path/from/previous/step" --model_name "SlotMLPAdditive" --epochs 300 --num_slots 2 --n_slot_latents 6 --use_consistency_loss True --consistency_ignite_epoch 100
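The commands with consistency loss also set --consistency_ignite_epoch (150 for SlotAttention, 100 for the autoencoder). Judging by its name, this flag sets the epoch from which the consistency loss is added to training. That reading, the optional weight, and whether the boundary epoch is inclusive are all assumptions, so check main.py for the actual behavior. Under those assumptions, the schedule reduces to a hypothetical helper like this:

```python
def training_loss(recon_loss, consistency_loss, epoch, ignite_epoch, weight=1.0):
    """Hypothetical schedule: reconstruction only before ignite_epoch, then add consistency.

    `weight` is an assumed scaling factor, not a documented main.py option, and treating
    ignite_epoch itself as the first epoch with the consistency term is also an assumption.
    """
    if epoch < ignite_epoch:
        return recon_loss
    return recon_loss + weight * consistency_loss


# Epochs before the ignite epoch ignore the consistency term.
for epoch in (0, 149, 150, 399):
    print(epoch, training_loss(0.5, 0.25, epoch, ignite_epoch=150))
```

Delaying the extra term lets the model first learn to reconstruct in-distribution data, so that the recombined slots it decodes are meaningful before they are used as targets.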

Evaluation

Evaluation is done with the src/evaluation.py script and closely follows the procedure and metrics of the training script. The main difference is that it also computes the compositional contrast; since this can cause out-of-memory (OOM) errors, it is calculated only for the AE model.

Here is an example command for evaluation:

python src/evaluation.py --dataset_path "/path/from/previous/step" --model_path "checkpoints/SlotMLPAdditive.pt" --model_name "SlotMLPAdditive" --n_slot_latents 6
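For intuition about the compositional contrast and why it is memory-hungry, here is a small NumPy sketch. It assumes the definition as we understand it from the paper, C(f, z) = Σ_n Σ_{k<j} ‖∂f_n/∂z_k‖ ‖∂f_n/∂z_j‖: for every output pixel n, multiply the norms of the decoder's Jacobian blocks for each pair of slots k, j, and sum. It is zero exactly when every pixel depends on at most one slot. The full Jacobian holds num_pixels × num_slots × slot_dim entries per sample, which fits the OOM note above. The function names, the toy decoders, and the finite-difference Jacobian are illustrative assumptions, not the code in src/evaluation.py.

```python
import numpy as np


def numerical_jacobian(decode, z, eps=1e-5):
    """Central-difference Jacobian of `decode` at one latent z of shape (num_slots, slot_dim).

    Returns an array of shape (num_pixels, num_slots, slot_dim).
    """
    num_slots, slot_dim = z.shape
    num_pixels = decode(z).size
    jac = np.zeros((num_pixels, num_slots, slot_dim))
    for k in range(num_slots):
        for i in range(slot_dim):
            dz = np.zeros_like(z)
            dz[k, i] = eps
            jac[:, k, i] = (decode(z + dz).ravel() - decode(z - dz).ravel()) / (2 * eps)
    return jac


def compositional_contrast(jac):
    """Sum over pixels of ||J_nk|| * ||J_nj|| for all slot pairs k < j."""
    norms = np.linalg.norm(jac, axis=2)  # (num_pixels, num_slots)
    # Pairwise products via sum_{k<j} a_k a_j = ((sum_k a_k)^2 - sum_k a_k^2) / 2
    pair_sums = (norms.sum(axis=1) ** 2 - (norms ** 2).sum(axis=1)) / 2
    return float(pair_sums.sum())


rng = np.random.default_rng(0)
w0, w1 = rng.normal(size=(2, 4, 3))  # two hypothetical per-slot weight matrices


def additive_decoder(z):
    # Slot 0 renders pixels 0-3 and slot 1 renders pixels 4-7: no pixel sees both slots.
    return np.concatenate([np.tanh(w0 @ z[0]), np.tanh(w1 @ z[1])])


def mixing_decoder(z):
    # Every pixel depends on both slots.
    return np.tanh(w0 @ z[0] + w1 @ z[1])


z = rng.normal(size=(2, 3))
print(compositional_contrast(numerical_jacobian(additive_decoder, z)))  # 0.0
print(compositional_contrast(numerical_jacobian(mixing_decoder, z)) > 0)  # True
```

A real evaluation would compute the Jacobian with automatic differentiation rather than finite differences, but the storage cost of the full Jacobian is the same either way.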

Misc

🔗 notebooks/2. Decoder Optimality.ipynb shows how the experiment described in Appendix B.2 is conducted. Because the paths to model checkpoints and the exact data splits are omitted, the notebook is mainly illustrative.

Citation

If you make use of this code in your own work, please cite our paper:

@inproceedings{wiedemer2024provable,
    title={Provable Compositional Generalization for Object-Centric Learning},
    author={Thadd{\"a}us Wiedemer and Jack Brady and Alexander Panfilov and Attila Juhos and Matthias Bethge and Wieland Brendel},
    booktitle={The Twelfth International Conference on Learning Representations},
    year={2024},
    url={https://openreview.net/forum?id=7VPTUWkiDQ}
}
