forked from facebookresearch/xformers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
conv_mlp.py
97 lines (77 loc) · 2.94 KB
/
conv_mlp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: Largely reusing the code from the reference VAN implementation
# see https://github.com/Visual-Attention-Network
import math
from dataclasses import dataclass
from typing import Optional
import torch.nn as nn
from xformers.components import Activation, build_activation
from xformers.components.feedforward import Feedforward, FeedforwardConfig
from . import register_feedforward
@dataclass
class ConvMlpConfig(FeedforwardConfig):
    """Configuration dataclass registered alongside ``Conv2DFeedforward``.

    NOTE(review): ``Conv2DFeedforward.__init__`` accepts an ``activation`` argument
    but no ``act_layer`` argument. If the registry forwards these fields as keyword
    arguments, ``act_layer`` would be swallowed by ``**kwargs`` and ignored — confirm
    against ``register_feedforward`` / ``FeedforwardConfig``.
    """

    # Hidden width multiplier; Conv2DFeedforward uses hidden_layer_multiplier * dim_model channels.
    hidden_layer_multiplier: int
    # Number of channels of each input token (in_channels of the first 1x1 conv).
    dim_model: int
    # Number of output channels; Conv2DFeedforward falls back to dim_model when this is falsy.
    dim_model_out: Optional[int]
    # Activation choice; presumably meant to mirror Conv2DFeedforward's `activation` — verify.
    act_layer: Activation
    # Dropout probability applied after the last 1x1 conv in Conv2DFeedforward.
    dropout: float
@register_feedforward("Conv2DFeedforward", ConvMlpConfig)
class Conv2DFeedforward(Feedforward):
    """
    A Convolutional feed-forward network, as proposed in VAN_ (Visual Attention Network, Guo et al.)

    The token sequence is viewed as a square 2D feature map and passed through
    1x1 conv -> 3x3 depthwise conv -> activation -> 1x1 conv -> dropout,
    then flattened back into a sequence.

    .. _VAN: https://arxiv.org/pdf/2202.09741.pdf
    """

    def __init__(
        self,
        dim_model: int,
        hidden_layer_multiplier: int = 1,
        dim_model_out: Optional[int] = None,
        activation: Activation = Activation.GeLU,
        dropout: float = 0.0,
        *args,
        **kwargs,
    ):
        """
        Args:
            dim_model: number of channels of each input token.
            hidden_layer_multiplier: the hidden width is ``hidden_layer_multiplier * dim_model``.
            dim_model_out: number of output channels; ``dim_model`` is used when this is falsy.
            activation: activation applied right after the depthwise convolution.
            dropout: dropout probability applied to the output.
            *args, **kwargs: accepted and ignored (presumably so extra config
                fields can be passed through — verify against the registry).
        """
        super().__init__()
        out_features = dim_model_out or dim_model
        hidden_features = hidden_layer_multiplier * dim_model

        self.conv_mlp = nn.Sequential(
            # Pointwise expansion to the hidden width
            nn.Conv2d(dim_model, hidden_features, 1),
            # 3x3 depthwise conv (one group per channel), stride 1, padding 1 keeps the spatial size
            nn.Conv2d(
                hidden_features,
                hidden_features,
                3,
                1,
                1,
                bias=True,
                groups=hidden_features,
            ),
            build_activation(activation),
            # Pointwise projection to the output width
            nn.Conv2d(hidden_features, out_features, 1),
            nn.Dropout(dropout),
        )

        # This feedforward requires a context length which is squared, often due to 2D pooling
        self.requires_squared_context = True

    def init_weights(self, **kwargs):
        """Initialize every ``nn.Conv2d`` in this module, following the original init.

        Weights are drawn from a normal distribution with mean 0 and standard
        deviation ``sqrt(2 / fan_out)``, where ``fan_out`` is divided by the
        number of groups (relevant for the depthwise conv); biases are zeroed.
        ``kwargs`` is accepted and ignored.
        """

        # Follow the original init, but also make it possible to initialize from the outside
        def init_module(m: nn.Module):
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                fan_out //= m.groups
                # nn.init performs the same in-place fill under no_grad, without touching `.data`
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        self.apply(init_module)

    def forward(self, x):
        """Apply the convolutional MLP to a ``(B, L, C)`` tensor.

        ``L`` must be a perfect square (``HW * HW``). Returns a ``(B, L, C_out)``
        tensor, where ``C_out`` is ``dim_model_out`` (or ``dim_model``).
        """
        # The conv layers expect channels on dim 1; we have NLC by default
        B, L, C = x.shape
        HW = int(math.sqrt(L))
        assert HW**2 == L, "Conv2DFeedforward requires squared context lengths"

        # (B, L, C) -> (B, HW, HW, C) -> (B, C, HW, HW).
        # NOTE: swapdims(1, -1) also swaps the two spatial axes, so the convs see the
        # transposed feature map. The transpose(1, -1) below undoes this, so token
        # order is preserved. Changing it would change which kernel taps hit which
        # neighbors, and already-trained weights would no longer match.
        x = x.reshape((B, HW, HW, C)).swapdims(1, -1)

        # The actual FW, including the 2d convolutions
        x = self.conv_mlp(x)

        # back to NLC: (B, C_out, HW, HW) -> (B, HW, HW, C_out) -> (B, L, C_out)
        x = x.transpose(1, -1)
        return x.flatten(1, 2)