This is the official PyTorch implementation of SlotSSMs presented in the paper:
Slot State Space Models
Jindong Jiang, Fei Deng, Gautam Singh, Minseung Lee, Sungjin Ahn
🌟 NeurIPS 2024 🌟
Project page: https://slotssms.github.io/
NeurIPS Submission: https://openreview.net/forum?id=BJv1t4XNJW
SlotSSMs are an efficient and powerful framework for video understanding that incorporates independent mechanisms into State Space Models (SSMs) such as Mamba, encouraging the separation of information across the entities in a scene and thereby improving visual reasoning. Unlike conventional SSMs, which maintain a monolithic state vector, SlotSSMs maintain a set of modular states whose dynamics are designed to be independent in time, with interactions modeled through self-attention in space.
By adopting SSMs at their core, SlotSSMs inherit their strengths: parallelizable training, memory efficiency, and long-range reasoning capabilities, giving them an advantage over RNN- and Transformer-based methods.
SlotSSMs vs existing models. (a) SlotSSMs incorporate modularity through independent state transitions and sparse interactions via self-attention. (b) Traditional SSMs utilize a monolithic state vector for all past information. (c) Multi-slot Transformer-based models offer modularity but with high computational complexity. (d) Multi-slot RNN-based models have modular states but can't parallelize training (red lock). SlotSSMs combine parallelizable training, memory efficiency, and modularity for efficient temporal modeling.
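At its core, a SlotSSM layer applies an SSM to each slot's state sequence independently over time and lets slots interact only through self-attention within each time step. The following is a minimal sketch of that computation (a simplified illustration with placeholder names, not the exact code in this repository); it assumes the mamba-ssm package and a CUDA device, since Mamba's fused kernels are CUDA-only.

import torch
import torch.nn as nn
from mamba_ssm import Mamba  # requires the mamba-ssm package and a CUDA GPU


class SlotSSMBlockSketch(nn.Module):
    def __init__(self, d_slot: int = 128, n_heads: int = 4):
        super().__init__()
        self.ssm = Mamba(d_model=d_slot)  # shared SSM applied to each slot independently
        self.attn = nn.MultiheadAttention(d_slot, n_heads, batch_first=True)
        self.norm_t = nn.LayerNorm(d_slot)
        self.norm_s = nn.LayerNorm(d_slot)

    def forward(self, slots: torch.Tensor) -> torch.Tensor:
        # slots: (batch B, time T, num_slots K, dim D)
        B, T, K, D = slots.shape

        # Independent temporal dynamics: fold slots into the batch dimension so that
        # each slot is processed by the SSM as its own length-T sequence.
        x = slots.permute(0, 2, 1, 3).reshape(B * K, T, D)
        x = x + self.ssm(self.norm_t(x))
        x = x.reshape(B, K, T, D).permute(0, 2, 1, 3)

        # Sparse interaction: self-attention over the K slots within each time step.
        y = x.reshape(B * T, K, D)
        h = self.norm_s(y)
        y = y + self.attn(h, h, h, need_weights=False)[0]
        return y.reshape(B, T, K, D)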
In our experiments, we evaluate SlotSSMs on long-context video understanding, video prediction, object-centric learning, 3D visual reasoning, and more. Below, we show qualitative results on the depth prediction task to showcase the emergent modularity of SlotSSMs on real-world video inputs.
TikTok Dataset
Waymo Dataset
UT Egocentric Dataset
Emergent Scene Decomposition from Depth Estimation Tasks. Colors represent the ID of slots used for predicting each position. SlotSSM is capable of exploiting the inherent modular structure of real-world videos for efficient inference, without explicit segmentation supervision.
This repository is a re-implementation of SlotSSMs with support for Mamba 1 and Mamba 2 for temporal modeling and FlashAttention for spatial attention. To facilitate model saving and loading, we base the model implementation on Hugging Face's PreTrainedModel class and the model configuration on PretrainedConfig. We use Hugging Face's Accelerate library to manage distributed and mixed precision training.
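As a rough illustration of this pattern (the class names and fields below are hypothetical, not the repository's exact definitions), subclassing PretrainedConfig and PreTrainedModel provides save_pretrained and from_pretrained for free:

import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class SlotSSMConfig(PretrainedConfig):
    model_type = "slotssm"

    def __init__(self, num_slots=8, d_slot=128, num_layers=6, **kwargs):
        super().__init__(**kwargs)
        self.num_slots = num_slots
        self.d_slot = d_slot
        self.num_layers = num_layers


class SlotSSMModel(PreTrainedModel):
    config_class = SlotSSMConfig

    def __init__(self, config: SlotSSMConfig):
        super().__init__(config)
        # Placeholder layers; the real model stacks SlotSSM blocks, an encoder, and a decoder.
        self.layers = nn.ModuleList(
            nn.Linear(config.d_slot, config.d_slot) for _ in range(config.num_layers)
        )


# Checkpoints then follow the standard Hugging Face interface:
# model = SlotSSMModel(SlotSSMConfig())
# model.save_pretrained("checkpoints/slotssm")
# model = SlotSSMModel.from_pretrained("checkpoints/slotssm")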
The project structure is shown below:
train.py: An example of model training on a reconstruction task using the CLEVRER dataset
src: Source code
    data: Data loading utilities
    models: Core model definitions
        encoder.py: Image encoder; we provide ViT-based and CNN-based encoders
        slotssm.py: SlotSSM model definition
        decoder.py: Image decoder; we provide ViT-based and CNN-based decoders
        modules.py: Eager implementation of MultiHeadAttention (MHA) and Inverted MHA
    utils: Utility functions
scripts: Helper scripts
    environment.sh: Environment setup
Set up your environment with the provided script scripts/environment.sh. It includes the following steps:
Create and activate a conda environment (Python 3.10 is used here):
conda create -n "slotssm" python=3.10 -y
conda activate slotssm
Install the pip toolkit and the PyTorch libraries (select the CUDA version compatible with your system):
conda install pip -y
pip install --upgrade pip
conda install -c nvidia cuda-toolkit -y
pip install torch torchvision torchaudio
(Optional) Install FlashAttention:
pip install flash-attn --no-build-isolation
Install Mamba dependencies:
pip install git+https://github.com/Dao-AILab/causal-conv1d
pip install git+https://github.com/state-spaces/mamba
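As a quick sanity check that Mamba is installed correctly (a CUDA GPU is required for its fused kernels), you can run a short snippet like the following:

import torch
from mamba_ssm import Mamba

layer = Mamba(d_model=64).cuda()
x = torch.randn(2, 16, 64, device="cuda")  # (batch, sequence length, model dim)
print(layer(x).shape)                      # expected: torch.Size([2, 16, 64])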
Install the remaining packages with the following command:
pip install transformers accelerate decord
We use the CLEVRER dataset as an example. Download the dataset from the official website and unzip it.
mkdir -p data/clevrer/videos/train/ && \
cd data/clevrer/videos/train/ && \
wget http://data.csail.mit.edu/clevrer/videos/train/video_train.zip && \
unzip video_train.zip
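The data loading utilities live in src/data. As a rough illustration of how a CLEVRER clip can be read with decord (the repository's actual loader may differ; the helper function and example path below are hypothetical):

import numpy as np
import torch
from decord import VideoReader, cpu


def load_clip(path, seq_len=16, size=64):
    # Read seq_len evenly spaced frames, resized to size x size, as a float tensor in [0, 1].
    vr = VideoReader(path, ctx=cpu(0), width=size, height=size)
    indices = np.linspace(0, len(vr) - 1, seq_len).astype(int)
    frames = vr.get_batch(indices).asnumpy()                # (T, H, W, C), uint8
    frames = torch.from_numpy(frames).permute(0, 3, 1, 2)   # (T, C, H, W)
    return frames.float() / 255.0


clip = load_clip("data/clevrer/videos/train/video_00000.mp4")  # adjust to an actual file
print(clip.shape)  # torch.Size([16, 3, 64, 64])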
After setting up the environment and preparing the data, you can train the model using train.py. We provide an example of representation learning on the CLEVRER dataset through video reconstruction.
For systems with limited GPU memory, adjust gradient_accumulation_steps and batch_size. The effective batch size is gradient_accumulation_steps x num_processes x batch_size; for example, the command below uses 1 x 4 x 8 = 32. You may also enable mixed-precision training to save memory.
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --multi_gpu --num_processes=4 --main_process_port 29500 train.py --seq_len 16 \
--train_data_path /path/to/your/dataset --mixed_precision fp16 --gradient_accumulation_steps 1 --batch_size 8
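Mixed precision and gradient accumulation are handled by Accelerate. The training loop follows the standard Accelerate pattern sketched below (a simplified stand-in with toy modules, not the exact contents of train.py); fp16 mixed precision assumes a GPU is available.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# Toy stand-ins for the real model and dataset, just to show the Accelerate pattern.
model = nn.Linear(32, 32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

accelerator = Accelerator(mixed_precision="fp16", gradient_accumulation_steps=1)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for (x,) in loader:
    with accelerator.accumulate(model):
        loss = nn.functional.mse_loss(model(x), x)  # placeholder reconstruction loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()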
If you find this code useful for your research, please cite our paper using the following BibTeX entry:
@article{jiang2024slot,
title={Slot State Space Models},
author={Jiang, Jindong and Deng, Fei and Singh, Gautam and Lee, Minseung and Ahn, Sungjin},
journal={NeurIPS},
year={2024}
}