mamba2-minimal

A minimal, single-file implementation of the Mamba-2 model in PyTorch.

Mamba-2

Transformers are SSMs: Generalized Models and Efficient Algorithms
Through Structured State Space Duality
Tri Dao*, Albert Gu*
Paper: https://arxiv.org/abs/2405.21060

Mamba is a new class of foundation models, most notable for not being based on the Transformer architecture. Instead, it belongs to the family of State Space Models (SSMs), which map a sequence through a hidden state in the fashion of an RNN. This approach gives computation and memory that scale linearly with sequence length during training (unlike the Transformer's quadratic complexity), as well as constant time per step during inference. Mamba-2 builds upon Mamba-1 by imposing additional constraints on certain SSM parameters (most notably, restricting the state matrix A to a scalar times the identity), which allows much larger state dimensions and significantly faster training.
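To make the linear-versus-quadratic contrast concrete, here is a minimal sketch of a single-head SSM whose state transition is one scalar per time step, assuming the scalar-times-identity structure described in the Mamba-2 paper. It computes the same outputs two ways: a linear-time recurrent scan, and a quadratic, attention-like masked matrix product (the "duality" in the paper's title). The function names `ssd_recurrent` and `ssd_quadratic` are hypothetical and do not come from `mamba2.py`. The inputs `a`, `B`, `C` are random here, whereas in Mamba they would be computed from the input sequence, and the paper's chunked training algorithm is not shown.

```python
import torch


def ssd_recurrent(a, B, C, x):
    """Linear-time scan: h_t = a_t * h_{t-1} + B_t x_t^T,  y_t = h_t^T C_t.

    a: (T,) scalar decay per step, B and C: (T, N), x: (T, P).
    Each step updates a fixed-size (N, P) state, so total cost is O(T * N * P).
    """
    T, N = B.shape
    P = x.shape[1]
    h = torch.zeros(N, P, dtype=x.dtype)
    ys = []
    for t in range(T):
        h = a[t] * h + torch.outer(B[t], x[t])  # decay old state, add new token
        ys.append(h.T @ C[t])                    # read out, shape (P,)
    return torch.stack(ys)                       # (T, P)


def ssd_quadratic(a, B, C, x):
    """Masked "attention" form: y = (L * (C @ B^T)) @ x,
    where L[i, j] = a_{j+1} * ... * a_i for i >= j and 0 otherwise."""
    T = a.shape[0]
    cs = torch.cumsum(torch.log(a), dim=0)
    seg = cs[:, None] - cs[None, :]              # log of the product a_{j+1..i}
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
    L = torch.exp(seg.masked_fill(~causal, float("-inf")))  # exp(-inf) = 0 above the diagonal
    M = L * (C @ B.T)                            # (T, T) matrix, like attention scores
    return M @ x                                 # (T, P)


torch.manual_seed(0)
T, N, P = 16, 8, 4
a = torch.rand(T, dtype=torch.float64) * 0.5 + 0.5  # decays in (0.5, 1.0)
B = torch.randn(T, N, dtype=torch.float64)
C = torch.randn(T, N, dtype=torch.float64)
x = torch.randn(T, P, dtype=torch.float64)

y_scan = ssd_recurrent(a, B, C, x)
y_quad = ssd_quadratic(a, B, C, x)
print(torch.allclose(y_scan, y_quad))  # True
```

The scan only ever holds an `(N, P)` state, while the quadratic form materializes a `(T, T)` matrix; that difference is where the linear versus quadratic scaling comes from.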

This implementation is device-agnostic and has been tested on the CPU and MPS (Metal Performance Shaders) backends. The model's output logits follow the same distribution as those of the reference implementation but are not numerically identical.
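A minimal sketch of choosing a backend at runtime, assuming a PyTorch build recent enough to expose `torch.backends.mps`. The helper name `pick_device` is hypothetical, and only the CPU and MPS backends named above are considered.

```python
import torch


def pick_device() -> torch.device:
    """Prefer Apple's MPS backend when it is available, otherwise fall back to the CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
x = torch.randn(2, 64, 768, device=device)  # (batch, seqlen, d_model), created directly on the device
print(device, x.device.type)
```

Because outputs are not expected to be bit-identical to the reference implementation, any comparison between the two should use a tolerance (for example `torch.allclose`) rather than exact equality.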

Usage

Install dependencies (torch, einops and transformers):

```shell
pip install -r requirements.txt
```

See demo.ipynb for using Mamba-2 as part of an end-to-end language model with pretrained weights for text generation.

The core Mamba-2 model can be used as follows:

```python
import torch

from mamba2 import Mamba2, Mamba2Config

config = Mamba2Config(d_model=768)
model = Mamba2(config)

x = torch.randn(2, 64, 768)  # (batch, seqlen, d_model)
y = model(x)  # same shape as x
```

TODOs

  • Constant-time (with respect to sequence length) autoregressive inference
  • Remove the dependency on einops (depending on whether the resulting code is still readable)
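The first TODO refers to decoding one token at a time by carrying a fixed-size state instead of reprocessing the whole prefix. Below is a minimal sketch of that idea for a single-head, scalar-decay SSM. `ssm_step` is a hypothetical name, the parameters are random stand-ins for values a real model would compute from each token, and the short causal convolution that Mamba blocks apply before the SSM (which would need its own small rolling cache) is left out.

```python
import torch


def ssm_step(h, a_t, B_t, C_t, x_t):
    """Advance a scalar-decay SSM by one token.

    h: (N, P) cached state, a_t: scalar decay, B_t and C_t: (N,), x_t: (P,).
    The cost depends only on N and P, not on how many tokens came before.
    """
    h = a_t * h + torch.outer(B_t, x_t)  # fold the new token into the fixed-size state
    y_t = h.T @ C_t                      # read out this step's output, shape (P,)
    return h, y_t


torch.manual_seed(0)
T, N, P = 32, 8, 4
dt = torch.float64
a = torch.rand(T, dtype=dt) * 0.5 + 0.5  # per-step decays in (0.5, 1.0)
B = torch.randn(T, N, dtype=dt)
C = torch.randn(T, N, dtype=dt)
x = torch.randn(T, P, dtype=dt)

h = torch.zeros(N, P, dtype=dt)  # the only thing carried from one step to the next
outputs = []
for t in range(T):
    h, y_t = ssm_step(h, a[t], B[t], C[t], x[t])
    outputs.append(y_t)
y = torch.stack(outputs)  # (T, P)
print(y.shape, h.shape)  # torch.Size([32, 4]) torch.Size([8, 4])
```

Each call to `ssm_step` does the same amount of work regardless of `t`, so generating `T` tokens costs time linear in `T` rather than quadratic.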

Credits

Resources

Some resources for understanding Mamba and SSMs.