
SlotSSMs: Slot State Space Models

This is the official PyTorch implementation of SlotSSMs presented in the paper:

Slot State Space Models
Jindong Jiang, Fei Deng, Gautam Singh, Minseung Lee, Sungjin Ahn
🌟 NeurIPS 2024 🌟
Project page: https://slotssms.github.io/
NeurIPS Submission: https://openreview.net/forum?id=BJv1t4XNJW

Highlights

SlotSSMs provide an efficient and powerful framework for video understanding by incorporating independent mechanisms into State Space Models (SSMs) such as Mamba, encouraging the separation of information between entities in a scene and thereby improving visual reasoning. Unlike conventional SSMs, which maintain a monolithic state vector, SlotSSMs maintain a set of modular states whose dynamics are independent in time, with interactions across slots modeled through self-attention in space.
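
To make this state structure concrete, below is a minimal, self-contained PyTorch sketch of the idea. It is illustrative only: the class and parameter names are ours, a simple gated linear recurrence stands in for Mamba, and the sequential loop is for readability (the actual model, in src/models/slotssm.py, uses an SSM with parallelizable training). Each slot's state evolves independently in time; slots interact only through self-attention within a time step.

import torch
import torch.nn as nn

class ToySlotSSMBlock(nn.Module):
    def __init__(self, num_slots, slot_dim, num_heads=4):
        super().__init__()
        # Shared per-channel transition applied to every slot independently.
        self.decay = nn.Parameter(torch.zeros(slot_dim))
        self.inp = nn.Linear(slot_dim, slot_dim)
        # Slots interact only through self-attention within each time step.
        self.attn = nn.MultiheadAttention(slot_dim, num_heads, batch_first=True)

    def forward(self, x):
        # x: (batch, time, num_slots, slot_dim) per-slot inputs
        B, T, K, D = x.shape
        h = x.new_zeros(B, K, D)
        out = []
        for t in range(T):
            # Independent temporal transition per slot (no cross-slot mixing).
            h = torch.sigmoid(self.decay) * h + self.inp(x[:, t])
            # Sparse interaction: self-attention over the K slots.
            mixed, _ = self.attn(h, h, h)
            out.append(mixed)
        return torch.stack(out, dim=1)  # (batch, time, num_slots, slot_dim)

For example, ToySlotSSMBlock(num_slots=4, slot_dim=64) maps an input of shape (2, 16, 4, 64) to an output of the same shape (slot_dim must be divisible by num_heads).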

By adopting SSMs at their core, SlotSSMs inherit their strengths: parallelizable training, memory efficiency, and long-range reasoning capabilities, giving them an advantage over RNN- and Transformer-based methods.

SlotSSM Method Overview

SlotSSMs vs existing models. (a) SlotSSMs incorporate modularity through independent state transitions and sparse interactions via self-attention. (b) Traditional SSMs utilize a monolithic state vector for all past information. (c) Multi-slot Transformer-based models offer modularity but with high computational complexity. (d) Multi-slot RNN-based models have modular states but can't parallelize training (red lock). SlotSSMs combine parallelizable training, memory efficiency, and modularity for efficient temporal modeling.

In our experiments, we evaluate SlotSSMs in long-context video understanding, video prediction, object-centric learning, 3D visual reasoning, and more. Below, we show qualitative results of the depth prediction task to showcase the emerging modularity in SlotSSMs for real-world video inputs.

Video Decomposition on TikTok Dataset


Video Decomposition on Waymo Dataset


Video Decomposition on UT Egocentric Dataset


Emergent Scene Decomposition from Depth Estimation Tasks. Colors represent the ID of slots used for predicting each position. SlotSSM is capable of exploiting the inherent modular structure of real-world videos for efficient inference, without explicit segmentation supervision.

Repository Overview

This repository is a re-implementation of SlotSSMs with support for Mamba 1 and Mamba 2 for temporal modeling and FlashAttention for spatial attention. To facilitate model saving and loading, we base the model implementation on Hugging Face's PreTrainedModel class and the model configuration on PretrainedConfig. We use Hugging Face's Accelerate library to manage distributed and mixed-precision training.
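
As a hedged illustration of this pattern (the class names below are hypothetical, not the repository's actual API), a config/model pair built on these base classes can be saved and reloaded from a checkpoint directory:

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel

class ToySlotConfig(PretrainedConfig):
    model_type = "toy-slotssm"  # hypothetical identifier

    def __init__(self, num_slots=8, slot_dim=64, **kwargs):
        self.num_slots = num_slots
        self.slot_dim = slot_dim
        super().__init__(**kwargs)

class ToySlotModel(PreTrainedModel):
    config_class = ToySlotConfig

    def __init__(self, config):
        super().__init__(config)
        # Learned initial slot states, sized from the config.
        self.init_slots = nn.Parameter(torch.zeros(config.num_slots, config.slot_dim))

model = ToySlotModel(ToySlotConfig())
model.save_pretrained("checkpoints/toy-slotssm")            # writes config.json + weights
model = ToySlotModel.from_pretrained("checkpoints/toy-slotssm")

The benefit of this design is that checkpoints carry their own configuration, so from_pretrained reconstructs the model without hand-passing hyperparameters.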

The project structure is shown below:

  • train.py: An example of model training using the CLEVRER dataset on a reconstruction task.
  • src: Source code
    • data: Data loading utilities
    • models: Core model definitions
      • encoder.py: Image encoder; both ViT-based and CNN-based encoders are provided
      • slotssm.py: SlotSSM model definition
      • decoder.py: Image decoder; both ViT-based and CNN-based decoders are provided
      • modules.py: Eager implementation of MultiHeadAttention (MHA) and Inverted MHA
    • utils: Utility functions
  • scripts: Helper scripts
    • environment.sh: Environment setup

Setup and Usage

Dependencies

Set up your environment with the provided script in scripts/environment.sh. It includes the following steps:

Create and activate a conda environment (Python 3.10 is used here):

conda create -n "slotssm" python=3.10 -y 
conda activate slotssm

Install the pip toolkit and PyTorch (select the CUDA version compatible with your system):

conda install pip -y
pip install --upgrade pip
conda install -c nvidia cuda-toolkit -y
pip install torch torchvision torchaudio

(Optional) Install FlashAttention:

pip install flash-attn --no-build-isolation

Install Mamba dependencies:

pip install git+https://github.com/Dao-AILab/causal-conv1d
pip install git+https://github.com/state-spaces/mamba

Install the remaining packages with the following command:

pip install transformers accelerate decord
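
After installation, a quick import check (a small sanity script of ours, not part of the repository) can confirm that the compiled dependencies are usable:

import torch
print("CUDA available:", torch.cuda.is_available())

from mamba_ssm import Mamba  # installed from state-spaces/mamba
print("Mamba import OK")

try:
    import flash_attn  # optional dependency
    print("FlashAttention import OK")
except ImportError:
    print("FlashAttention not installed (optional)")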

Preparing Data

We use the CLEVRER dataset as an example. Download the training videos from the official website and unzip them:

mkdir -p data/clevrer/videos/train/ && \
cd data/clevrer/videos/train/ && \
wget http://data.csail.mit.edu/clevrer/videos/train/video_train.zip && \
unzip video_train.zip
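
To verify the download, you can load a clip with decord. This is a sanity check, not the repository's data pipeline (which lives in src/data), and the file path below assumes the default layout of the unzipped archive:

from decord import VideoReader

vr = VideoReader("data/clevrer/videos/train/video_00000-01000/video_00000.mp4")
print("total frames:", len(vr))
clip = vr.get_batch(list(range(16))).asnumpy()  # (16, H, W, 3) uint8 array
print("clip shape:", clip.shape)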

Training

After setting up the environment and preparing the data, you can train the model using train.py. We provide an example of representation learning on the CLEVRER dataset through video reconstruction.

For systems with limited GPU memory, adjust gradient_accumulation_steps and batch_size. The effective batch size is gradient_accumulation_steps x num_processes x batch_size; the example command below therefore uses an effective batch size of 1 x 4 x 8 = 32. You may also use mixed-precision training to save memory.

CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --multi_gpu --num_processes=4 --main_process_port 29500 train.py --seq_len 16 \
--train_data_path /path/to/your/dataset --mixed_precision fp16 --gradient_accumulation_steps 1 --batch_size 8

Citation

If you find this code useful for your research, please cite our paper with the following BibTeX entry:

@inproceedings{jiang2024slot,
  title={Slot State Space Models},
  author={Jiang, Jindong and Deng, Fei and Singh, Gautam and Lee, Minseung and Ahn, Sungjin},
  booktitle={Advances in Neural Information Processing Systems},
  year={2024}
}
