An easy-to-use and efficient implementation of xLSTM. Here are a few articles to help you understand it:
- Understanding xLSTM through code implementation (PyTorch)
- Implement the xLSTM paper from scratch with PyTorch
Apologies for any mistakes in the docs: I am French and don't speak English very well. If you like this repository, please give it a star.
Install directly from GitHub:

```
pip install git+https://github.com/styalai/xLSTM-pytorch
```
Example of use:

```python
import torch
import torch.nn as nn
from xLSTM.xLSTM import xLSTM as xlstm

batch_size = 4
seq_length = 8
input_size = 32
x_example = torch.zeros(batch_size, seq_length, input_size)  # shape template used to initialize the model

factor = 2  # factor by which input_size is multiplied to give hidden_size
depth = 4   # number of blocks for q, k and v
layers = 'ms'  # 'm' for an mLSTM block, 's' for an sLSTM block

model = xlstm(layers, x_example, factor=factor, depth=depth)

x = torch.randn(batch_size, seq_length, input_size)
out = model(x)
print(out.shape)
# torch.Size([4, 8, 32])
```
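The `layers` string controls the block stack, so deeper models are built by extending it. A minimal sketch, continuing from the example above and assuming the constructor accepts any combination of 'm' and 's' characters (only 'ms' is shown in this README):

```python
# Deeper stack: m -> s -> m -> s -> m.
# Assumption: any string of 'm'/'s' characters is a valid layers argument.
deep_model = xlstm('msmsm', x_example, factor=factor, depth=depth)
print(deep_model(x).shape)  # expected: torch.Size([4, 8, 32]); the blocks preserve the input shape
```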
I tested an mLSTM block (18M parameters) on an NLP task (Karpathy's Tiny Shakespeare dataset); the code is in 'examples/tinyshakespeare'. The model overfits slightly.
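For orientation, here is a rough sketch of that kind of character-level setup. The file name, sizes, and hyperparameters below are illustrative assumptions, not the repo's exact script; see 'examples/tinyshakespeare' for the real one:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from xLSTM.xLSTM import xLSTM as xlstm

# Illustrative character-level language-modelling loop on Tiny Shakespeare.
# File name, sizes, and hyperparameters are assumptions, not the repo's values.
text = open('input.txt').read()
chars = sorted(set(text))
stoi = {c: i for i, c in enumerate(chars)}
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)

vocab_size, emb_size = len(chars), 64
batch_size, seq_length = 16, 64

embed = nn.Embedding(vocab_size, emb_size)
x_example = torch.zeros(batch_size, seq_length, emb_size)
model = xlstm('m', x_example, factor=2, depth=4)  # a single mLSTM block
head = nn.Linear(emb_size, vocab_size)

params = list(embed.parameters()) + list(model.parameters()) + list(head.parameters())
optimizer = torch.optim.AdamW(params, lr=3e-4)

for step in range(100):
    # Sample random windows and predict the next character at every position.
    ix = torch.randint(len(data) - seq_length - 1, (batch_size,))
    xb = torch.stack([data[i:i + seq_length] for i in ix])
    yb = torch.stack([data[i + 1:i + seq_length + 1] for i in ix])
    logits = head(model(embed(xb)))  # (batch, seq, vocab_size)
    loss = F.cross_entropy(logits.view(-1, vocab_size), yb.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```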
If you use xLSTM-pytorch in your research or projects, please cite the original xLSTM paper:

```bibtex
@article{Beck2024xLSTM,
  title={xLSTM: Extended Long Short-Term Memory},
  author={Beck, Maximilian and Pöppel, Korbinian and Spanring, Markus and Auer, Andreas and Prudnikova, Oleksandra and Kopp, Michael and Klambauer, Günter and Brandstetter, Johannes and Hochreiter, Sepp},
  journal={arXiv preprint arXiv:2405.04517},
  year={2024}
}
```
The original code of 'xLSTM/utils.py' comes from https://github.com/akaashdash/xlstm.