
Commit

add permutator
lucidrains committed Jun 24, 2021
1 parent d6901e7 commit dc1efbb
Showing 4 changed files with 71 additions and 1 deletion.
11 changes: 11 additions & 0 deletions README.md
@@ -45,3 +45,14 @@ pred = model(img) # (1, 1000)
primaryClass = {cs.CV}
}
```

```bibtex
@misc{hou2021vision,
title = {Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition},
author = {Qibin Hou and Zihang Jiang and Li Yuan and Ming-Ming Cheng and Shuicheng Yan and Jiashi Feng},
year = {2021},
eprint = {2106.12368},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
1 change: 1 addition & 0 deletions mlp_mixer_pytorch/__init__.py
@@ -1 +1,2 @@
from mlp_mixer_pytorch.mlp_mixer_pytorch import MLPMixer
from mlp_mixer_pytorch.permutator import Permutator
58 changes: 58 additions & 0 deletions mlp_mixer_pytorch/permutator.py
@@ -0,0 +1,58 @@
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce

# applies LayerNorm before the wrapped module and adds a residual connection
class PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x

# applies each branch to the same input and sums the results
class ParallelSum(nn.Module):
    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        return sum(map(lambda fn: fn(x), self.fns))

# Vision Permutator: an MLP-like model that mixes information along the height, width, and channel axes
def Permutator(*, image_size, patch_size, dim, depth, num_classes, expansion_factor = 4, dropout = 0.):
    assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
    height = width = image_size // patch_size
    num_patches = height * width

    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
        nn.Linear((patch_size ** 2) * 3, dim),
        *[nn.Sequential(
            PreNormResidual(dim, nn.Sequential(
                ParallelSum(
                    nn.Sequential(
                        Rearrange('b h w c -> b w c h'),  # mix along the height axis
                        nn.Linear(height, height),
                        Rearrange('b w c h -> b h w c'),
                    ),
                    nn.Sequential(
                        Rearrange('b h w c -> b h c w'),  # mix along the width axis
                        nn.Linear(width, width),
                        Rearrange('b h c w -> b h w c'),
                    ),
                    nn.Linear(dim, dim)  # mix along the channel axis
                ),
                nn.Linear(dim, dim)
            )),
            PreNormResidual(dim, nn.Sequential(
                nn.Linear(dim, dim * expansion_factor),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(dim * expansion_factor, dim),
                nn.Dropout(dropout)
            ))
        ) for _ in range(depth)],
        nn.LayerNorm(dim),
        Reduce('b h w c -> b c', 'mean'),
        nn.Linear(dim, num_classes)
    )
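
For reference, a minimal usage sketch of the new `Permutator` module. The hyperparameter values below are illustrative, not part of this commit; only the keyword names and the `Permutator` export in `__init__.py` come from the diff above:

```python
import torch
from mlp_mixer_pytorch import Permutator

model = Permutator(
    image_size = 224,   # must be divisible by patch_size
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)

img = torch.randn(1, 3, 224, 224)
pred = model(img)  # (1, 1000)
```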
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'mlp-mixer-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.7',
version = '0.0.9',
license='MIT',
description = 'MLP Mixer - Pytorch',
author = 'Phil Wang',
