diff --git a/README.md b/README.md
index baa5f9f..c0ad5a9 100644
--- a/README.md
+++ b/README.md
@@ -45,3 +45,14 @@ pred = model(img) # (1, 1000)
     primaryClass = {cs.CV}
 }
 ```
+
+```bibtex
+@misc{hou2021vision,
+    title   = {Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition},
+    author  = {Qibin Hou and Zihang Jiang and Li Yuan and Ming-Ming Cheng and Shuicheng Yan and Jiashi Feng},
+    year    = {2021},
+    eprint  = {2106.12368},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
diff --git a/mlp_mixer_pytorch/__init__.py b/mlp_mixer_pytorch/__init__.py
index 37532b0..56d0776 100644
--- a/mlp_mixer_pytorch/__init__.py
+++ b/mlp_mixer_pytorch/__init__.py
@@ -1 +1,2 @@
 from mlp_mixer_pytorch.mlp_mixer_pytorch import MLPMixer
+from mlp_mixer_pytorch.permutator import Permutator
diff --git a/mlp_mixer_pytorch/permutator.py b/mlp_mixer_pytorch/permutator.py
new file mode 100644
index 0000000..7570ebe
--- /dev/null
+++ b/mlp_mixer_pytorch/permutator.py
@@ -0,0 +1,57 @@
+from torch import nn
+from einops.layers.torch import Rearrange, Reduce
+
+class PreNormResidual(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.fn = fn
+        self.norm = nn.LayerNorm(dim)
+
+    def forward(self, x):
+        return self.fn(self.norm(x)) + x
+
+class ParallelSum(nn.Module):
+    def __init__(self, *fns):
+        super().__init__()
+        self.fns = nn.ModuleList(fns)
+
+    def forward(self, x):
+        return sum(map(lambda fn: fn(x), self.fns))
+
+def Permutator(*, image_size, patch_size, dim, depth, num_classes, expansion_factor = 4, dropout = 0.):
+    assert (image_size % patch_size) == 0, 'image size must be divisible by patch size'
+    height = width = image_size // patch_size
+    num_patches = height * width
+
+    return nn.Sequential(
+        Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+        nn.Linear((patch_size ** 2) * 3, dim),  # patch embedding (3-channel input)
+        *[nn.Sequential(
+            PreNormResidual(dim, nn.Sequential(
+                ParallelSum(
+                    nn.Sequential(
+                        Rearrange('b h w c -> b w c h'),
+                        nn.Linear(height, height),  # mix along height
+                        Rearrange('b w c h -> b h w c'),
+                    ),
+                    nn.Sequential(
+                        Rearrange('b h w c -> b h c w'),
+                        nn.Linear(width, width),  # mix along width
+                        Rearrange('b h c w -> b h w c'),
+                    ),
+                    nn.Linear(dim, dim)  # mix along channels
+                ),
+                nn.Linear(dim, dim)
+            )),
+            PreNormResidual(dim, nn.Sequential(
+                nn.Linear(dim, dim * expansion_factor),
+                nn.GELU(),
+                nn.Dropout(dropout),
+                nn.Linear(dim * expansion_factor, dim),
+                nn.Dropout(dropout)
+            ))
+        ) for _ in range(depth)],
+        nn.LayerNorm(dim),
+        Reduce('b h w c -> b c', 'mean'),  # global average pool
+        nn.Linear(dim, num_classes)
+    )
diff --git a/setup.py b/setup.py
index bded924..df6ca50 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'mlp-mixer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.7',
+  version = '0.0.9',
   license='MIT',
   description = 'MLP Mixer - Pytorch',
   author = 'Phil Wang',
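
The new `Permutator` follows the same usage pattern as the existing `MLPMixer` (compare the `pred = model(img)` context line in the README hunk above). A minimal sketch: the hyperparameter values below are illustrative, not taken from the diff, and the 3-channel input is assumed because the patch embedding hardcodes `(patch_size ** 2) * 3`.

```python
import torch
from mlp_mixer_pytorch import Permutator

# illustrative hyperparameters; image_size must be divisible by patch_size
model = Permutator(
    image_size = 256,
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)  # 3 channels assumed (hardcoded in the patch embed)
pred = model(img)                  # (1, 1000)
```

Note that, unlike attention or the token-mixing MLP in `MLPMixer`, each block here mixes information along height, width, and channels in three parallel linear branches that are summed, so the spatial weights are tied per-row and per-column rather than per-patch.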