-
Notifications
You must be signed in to change notification settings - Fork 1
/
nnmamba.py
243 lines (198 loc) · 8.58 KB
/
nnmamba.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding."""
return nn.Conv3d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=False,
dilation=dilation,
)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution."""
return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, mamba_layer=None):
super(BasicBlock, self).__init__()
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm3d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm3d(planes)
self.mamba_layer = mamba_layer
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.mamba_layer is not None:
global_att = self.mamba_layer(x)
out += global_att
if self.downsample is not None:
# if self.mamba_layer is not None:
# global_att = self.mamba_layer(x)
# identity = self.downsample(x+global_att)
# else:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
def make_res_layer(inplanes, planes, blocks, stride=1, mamba_layer=None):
downsample = nn.Sequential(
conv1x1(inplanes, planes, stride),
nn.BatchNorm3d(planes),
)
layers = []
layers.append(BasicBlock(inplanes, planes, stride, downsample))
for _ in range(1, blocks):
layers.append(BasicBlock(planes, planes, mamba_layer=mamba_layer))
return nn.Sequential(*layers)
class MambaLayer(nn.Module):
def __init__(self, dim, d_state=16, d_conv=4, expand=2):
super().__init__()
self.dim = dim
self.nin = conv1x1(dim, dim)
self.nin2 = conv1x1(dim, dim)
self.norm2 = nn.BatchNorm3d(dim) # LayerNorm
self.relu2 = nn.ReLU(inplace=True)
self.relu3 = nn.ReLU(inplace=True)
self.norm = nn.BatchNorm3d(dim) # LayerNorm
self.relu = nn.ReLU(inplace=True)
self.mamba = Mamba(
d_model=dim, # Model dimension d_model
d_state=d_state, # SSM state expansion factor
d_conv=d_conv, # Local convolution width
expand=expand # Block expansion factor
)
def forward(self, x):
B, C = x.shape[:2]
x = self.nin(x)
x = self.norm(x)
x = self.relu(x)
act_x = x
assert C == self.dim
n_tokens = x.shape[2:].numel()
img_dims = x.shape[2:]
x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
# x_mamba = self.mamba(x_flat)
x_flip_l = torch.flip(x_flat, dims=[2])
x_flip_c = torch.flip(x_flat, dims=[1])
x_flip_lc = torch.flip(x_flat, dims=[1,2])
x_ori = self.mamba(x_flat)
x_mamba_l = self.mamba(x_flip_l)
x_mamba_c = self.mamba(x_flip_c)
x_mamba_lc = self.mamba(x_flip_lc)
x_ori_l = torch.flip(x_mamba_l, dims=[2])
x_ori_c = torch.flip(x_mamba_c, dims=[1])
x_ori_lc = torch.flip(x_mamba_lc, dims=[1,2])
x_mamba = (x_ori+x_ori_l+x_ori_c+x_ori_lc)/4
out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims)
# act_x = self.relu3(x)
out += act_x
out = self.nin2(out)
out = self.norm2(out)
out = self.relu2(out)
return out
class DoubleConv(nn.Module):
def __init__(self, in_ch, out_ch, stride=1, kernel_size=3):
super(DoubleConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=int(kernel_size / 2)),
nn.BatchNorm3d(out_ch),
nn.ReLU(inplace=True),
nn.Conv3d(out_ch, out_ch, 3, padding=1, dilation=1),
nn.BatchNorm3d(out_ch),
nn.ReLU(inplace=True),
)
def forward(self, input):
return self.conv(input)
class SingleConv(nn.Module):
def __init__(self, in_ch, out_ch):
super(SingleConv, self).__init__()
self.conv = nn.Sequential(nn.Conv3d(in_ch, out_ch, 3, padding=1), nn.BatchNorm3d(out_ch), nn.ReLU(inplace=True))
def forward(self, input):
return self.conv(input)
class Attentionlayer(nn.Module):
def __init__(self,dim,r=16,act='relu'):
super(Attentionlayer, self).__init__()
self.layer1 = nn.Linear(dim, int(dim//r))
self.layer2 = nn.Linear(int(dim//r), dim)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def forward(self, inp):
att = self.sigmoid(self.layer2(self.relu(self.layer1(inp))))
return att.unsqueeze(-1)
class nnMamba(nn.Module):
def __init__(self, in_ch=1, channels=32, blocks=3, number_classes=2):
super(nnMamba, self).__init__()
self.in_conv = DoubleConv(in_ch, channels, stride=2, kernel_size=3)
# self.mamba_layer_stem = MambaLayer(channels)
self.pooling = nn.AdaptiveAvgPool3d((1, 1, 1))
self.att1 = Attentionlayer(channels)
self.layer1 = make_res_layer(channels, channels * 2, blocks, stride=2, mamba_layer=MambaLayer(channels*2))
self.att2 = Attentionlayer(channels*2)
self.layer2 = make_res_layer(channels * 2, channels * 4, blocks, stride=2, mamba_layer=MambaLayer(channels*4))
# self.mamba_layer_2 = MambaLayer(channels*4)
self.att3 = Attentionlayer(channels*4)
self.layer3 = make_res_layer(channels * 4, channels * 8, blocks, stride=2, mamba_layer=MambaLayer(channels*8))
# self.mamba_layer_3 = MambaLayer(channels*8)
self.up5 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)
self.conv5 = DoubleConv(channels * 12, channels * 4)
self.up6 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)
self.conv6 = DoubleConv(channels * 6, channels * 2)
self.up7 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)
self.conv7 = DoubleConv(channels * 3, channels)
self.up8 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)
self.conv8 = DoubleConv(channels, number_classes)
def forward(self, x):
c1 = self.in_conv(x)
scale_f1 = self.att1(self.pooling(c1).reshape(c1.shape[0], c1.shape[1])).reshape(c1.shape[0], c1.shape[1], 1, 1, 1)
# c1_s = self.mamba_layer_stem(c1) + c1
c2 = self.layer1(c1)
# c2_s = self.mamba_layer_1(c2) + c2
scale_f2 = self.att2(self.pooling(c2).reshape(c2.shape[0], c2.shape[1])).reshape(c2.shape[0], c2.shape[1], 1, 1, 1)
c3 = self.layer2(c2)
# c3_s = self.mamba_layer_2(c3) + c3
scale_f3 = self.att3(self.pooling(c3).reshape(c3.shape[0], c3.shape[1])).reshape(c3.shape[0], c3.shape[1], 1, 1, 1)
c4 = self.layer3(c3)
# c4_s = self.mamba_layer_3(c4) + c4
up_5 = self.up5(c4)
merge5 = torch.cat([up_5, c3*scale_f3], dim=1)
c5 = self.conv5(merge5)
up_6 = self.up6(c5)
merge6 = torch.cat([up_6, c2*scale_f2], dim=1)
c6 = self.conv6(merge6)
up_7 = self.up7(c6)
merge7 = torch.cat([up_7, c1*scale_f1], dim=1)
c7 = self.conv7(merge7)
up_8 = self.up8(c7)
c8 = self.conv8(up_8)
return c8
if __name__ == "__main__":
import numpy as np
input_shape = (2, 1, 128, 128, 128) # (batch_size, channels, height, width, depth)
input_data = torch.randn(input_shape).cuda()
model = nnMamba().cuda()
model_parameters = filter(lambda p: p.requires_grad , model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print(f"总参数数量:{params / 1e6}M")
# 获取每个层的参数数量和名称
for name , param in model.named_parameters():
print(f"层名称: {name}, 参数形状: {param.shape}")
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params / 1e6}M")
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params / 1e6}M")
pass