-
Notifications
You must be signed in to change notification settings - Fork 433
/
LightCNN.py
200 lines (170 loc) · 7.28 KB
/
LightCNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
'''
Implementation of the Light CNN networks (9-layer, 29-layer and 29-layer v2 variants) in PyTorch.
@author: Alfred Xiang Wu
@date: 2017.07.04
'''
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        """Return ``input`` viewed as ``(input.size(0), -1)``.

        ``view`` is used, so the tensor must be viewable with that shape
        (no copy is made).
        """
        leading = input.shape[0]
        return input.view(leading, -1)
class mfm(nn.Module):
    """Layer that computes ``2 * out_channels`` responses and keeps the
    element-wise maximum of the two halves.

    (``mfm`` presumably stands for the "Max-Feature-Map" activation of the
    Light CNN paper -- confirm against the reference.)

    Args:
        in_channels: input channels (conv) or input features (linear).
        out_channels: size of the result along dim 1.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``; unused for
            the linear variant.
        type: ``1`` builds an ``nn.Conv2d`` filter, any other value builds an
            ``nn.Linear`` filter.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, type=1):
        super().__init__()
        self.out_channels = out_channels
        doubled = 2 * out_channels  # two candidates per output unit
        if type != 1:
            self.filter = nn.Linear(in_channels, doubled)
        else:
            self.filter = nn.Conv2d(
                in_channels,
                doubled,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )

    def forward(self, x):
        """Apply the filter, then take the max of the first two
        ``out_channels``-sized chunks along dim 1."""
        responses = self.filter(x)
        chunks = responses.split(self.out_channels, dim=1)
        return torch.max(chunks[0], chunks[1])
class group(nn.Module):
    """A 1x1 ``mfm`` convolution followed by a configurable ``mfm`` convolution.

    Args:
        in_channels: channels of the input; the 1x1 stage keeps this width.
        out_channels: channels produced by the second stage.
        kernel_size, stride, padding: settings of the second stage.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        # Channel-preserving 1x1 stage (kernel 1, stride 1, no padding).
        self.conv_a = mfm(in_channels, in_channels, 1, 1, 0)
        self.conv = mfm(in_channels, out_channels, kernel_size, stride, padding)

    def forward(self, x):
        """Run ``conv_a`` and then ``conv`` on ``x``."""
        return self.conv(self.conv_a(x))
class resblock(nn.Module):
    """Residual block: two 3x3 ``mfm`` convolutions plus an identity shortcut.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions. Because the
            input is added to the output unchanged, this is expected to equal
            ``in_channels`` (every use in this file passes equal values).
    """

    def __init__(self, in_channels, out_channels):
        super(resblock, self).__init__()
        self.conv1 = mfm(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # conv2 consumes conv1's output, so its input width is out_channels
        # (previously in_channels; identical whenever the two are equal).
        self.conv2 = mfm(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return ``conv2(conv1(x)) + x``."""
        res = x
        out = self.conv1(x)
        out = self.conv2(out)
        return out + res
class network_9layers(nn.Module):
    """Nine-layer Light CNN classifier.

    The convolutional stack takes single-channel input and applies four
    stride-2 max-poolings (``ceil_mode=True``). ``fc1`` expects ``8*8*128``
    flattened features, i.e. an 8x8 map after the last pooling (for example
    a 128x128 input).

    Args:
        num_classes: output size of the final linear layer.
    """

    def __init__(self, num_classes=79077):
        super().__init__()
        pool_cfg = dict(kernel_size=2, stride=2, ceil_mode=True)
        stages = [
            mfm(1, 48, 5, 1, 2),
            nn.MaxPool2d(**pool_cfg),
            group(48, 96, 3, 1, 1),
            nn.MaxPool2d(**pool_cfg),
            group(96, 192, 3, 1, 1),
            nn.MaxPool2d(**pool_cfg),
            group(192, 128, 3, 1, 1),
            group(128, 128, 3, 1, 1),
            nn.MaxPool2d(**pool_cfg),
        ]
        self.features = nn.Sequential(*stages)
        self.fc1 = mfm(8 * 8 * 128, 256, type=0)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return ``(logits, feature)``.

        ``feature`` is the 256-dim ``fc1`` output after dropout (dropout is
        only active in training mode); ``logits`` is ``fc2(feature)``.
        """
        maps = self.features(x)
        flat = maps.view(maps.size(0), -1)
        feature = F.dropout(self.fc1(flat), training=self.training)
        return self.fc2(feature), feature
class network_29layers(nn.Module):
    """29-layer Light CNN classifier built from residual blocks and groups.

    The first convolution takes 3-channel input. Four stride-2 max-poolings
    (``ceil_mode=True``) precede ``fc``, which expects ``8*8*128`` flattened
    features, i.e. an 8x8 map after the last pooling (for example a 128x128
    input).

    Args:
        block: callable ``block(in_channels, out_channels)`` returning a module.
        layers: sequence of four ints, the number of blocks in each stage.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=79077):
        super().__init__()
        pool_cfg = dict(kernel_size=2, stride=2, ceil_mode=True)

        # Stem: 3-channel input.
        self.conv1 = mfm(3, 48, 5, 1, 2)
        self.pool1 = nn.MaxPool2d(**pool_cfg)

        self.block1 = self._make_layer(block, layers[0], 48, 48)
        self.group1 = group(48, 96, 3, 1, 1)
        self.pool2 = nn.MaxPool2d(**pool_cfg)

        self.block2 = self._make_layer(block, layers[1], 96, 96)
        self.group2 = group(96, 192, 3, 1, 1)
        self.pool3 = nn.MaxPool2d(**pool_cfg)

        self.block3 = self._make_layer(block, layers[2], 192, 192)
        self.group3 = group(192, 128, 3, 1, 1)

        self.block4 = self._make_layer(block, layers[3], 128, 128)
        self.group4 = group(128, 128, 3, 1, 1)
        self.pool4 = nn.MaxPool2d(**pool_cfg)

        self.fc = mfm(8 * 8 * 128, 256, type=0)
        self.fc2 = nn.Linear(256, num_classes)

    def _make_layer(self, block, num_blocks, in_channels, out_channels):
        """Stack ``num_blocks`` instances of ``block`` in an ``nn.Sequential``."""
        stacked = [block(in_channels, out_channels) for _ in range(num_blocks)]
        return nn.Sequential(*stacked)

    def forward(self, x):
        """Return ``(logits, feature)``.

        ``feature`` is the 256-dim ``fc`` output after dropout (dropout is
        only active in training mode); ``logits`` is ``fc2(feature)``.
        """
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.group1(self.block1(x)))
        x = self.pool3(self.group2(self.block2(x)))
        x = self.group3(self.block3(x))
        x = self.pool4(self.group4(self.block4(x)))

        flat = x.view(x.size(0), -1)
        fc = F.dropout(self.fc(flat), training=self.training)
        return self.fc2(fc), fc
class network_29layers_v2(nn.Module):
    """29-layer Light CNN (v2) that outputs a ``feat_dim`` embedding.

    The first convolution takes 3-channel input. Down-sampling is done four
    times by summing a 2x2 max-pool and a 2x2 average-pool, so the spatial
    size shrinks by 16 (e.g. 112x112 -> 7x7).

    Args:
        block: callable ``block(in_channels, out_channels)`` returning a module.
        layers: sequence of four ints, the number of blocks in each stage.
        drop_ratio: dropout probability used in the output layer.
        out_h, out_w: height and width of the final 128-channel feature map;
            they size the linear layer (``128 * out_h * out_w`` inputs).
        feat_dim: dimension of the returned embedding.
    """

    def __init__(self, block, layers, drop_ratio, out_h, out_w, feat_dim):
        super().__init__()
        # Stem: 3-channel input.
        self.conv1 = mfm(3, 48, 5, 1, 2)

        self.block1 = self._make_layer(block, layers[0], 48, 48)
        self.group1 = group(48, 96, 3, 1, 1)

        self.block2 = self._make_layer(block, layers[1], 96, 96)
        self.group2 = group(96, 192, 3, 1, 1)

        self.block3 = self._make_layer(block, layers[2], 192, 192)
        self.group3 = group(192, 128, 3, 1, 1)

        self.block4 = self._make_layer(block, layers[3], 128, 128)
        self.group4 = group(128, 128, 3, 1, 1)

        head = [
            nn.BatchNorm2d(128),
            nn.Dropout(drop_ratio),
            Flatten(),
            nn.Linear(128 * out_h * out_w, feat_dim),
            nn.BatchNorm1d(feat_dim),
        ]
        self.output_layer = nn.Sequential(*head)

    def _make_layer(self, block, num_blocks, in_channels, out_channels):
        """Stack ``num_blocks`` instances of ``block`` in an ``nn.Sequential``."""
        stacked = [block(in_channels, out_channels) for _ in range(num_blocks)]
        return nn.Sequential(*stacked)

    @staticmethod
    def _mixed_pool(x):
        """Halve the spatial size: 2x2 max-pool plus 2x2 average-pool."""
        return F.max_pool2d(x, 2) + F.avg_pool2d(x, 2)

    def forward(self, x):
        """Return the ``feat_dim`` embedding produced by ``output_layer``."""
        x = self._mixed_pool(self.conv1(x))
        x = self._mixed_pool(self.group1(self.block1(x)))
        x = self._mixed_pool(self.group2(self.block2(x)))
        x = self.group3(self.block3(x))
        # After the fourth pooling the map is out_h x out_w (7x7 for 112x112 input).
        x = self._mixed_pool(self.group4(self.block4(x)))
        return self.output_layer(x)
def LightCNN_9Layers(drop_ratio, out_h, out_w, feat_dim):
    """Build the 9-layer Light CNN.

    ``network_9layers`` only accepts ``num_classes``; forwarding all four
    arguments positionally (as before) raised ``TypeError`` on every call.
    ``feat_dim`` is therefore used as the size of the final linear layer.

    Args:
        drop_ratio: unused -- ``network_9layers`` calls ``F.dropout`` with its
            default probability.
        out_h, out_w: unused -- ``network_9layers`` hard-codes an 8x8 final
            feature map (``8*8*128`` inputs to ``fc1``).
        feat_dim: output size of the final linear layer.

    Returns:
        A ``network_9layers`` instance. NOTE(review): its ``forward`` returns
        a ``(logits, feature)`` tuple and expects 1-channel input, unlike
        ``network_29layers_v2`` -- confirm this is what callers expect.
    """
    return network_9layers(num_classes=feat_dim)
def LightCNN_29Layers(drop_ratio, out_h, out_w, feat_dim):
    """Build the original (v1) 29-layer Light CNN with block counts [1, 2, 3, 4].

    ``network_29layers`` accepts ``(block, layers, num_classes)``; passing the
    four extra arguments positionally (as before) raised ``TypeError`` on
    every call. ``feat_dim`` is therefore used as the size of the final
    linear layer.

    Args:
        drop_ratio: unused -- ``network_29layers`` calls ``F.dropout`` with
            its default probability.
        out_h, out_w: unused -- ``network_29layers`` hard-codes an 8x8 final
            feature map (``8*8*128`` inputs to ``fc``).
        feat_dim: output size of the final linear layer.

    Returns:
        A ``network_29layers`` instance. NOTE(review): its ``forward`` returns
        a ``(logits, feature)`` tuple, unlike ``network_29layers_v2`` --
        confirm this is what callers expect.
    """
    return network_29layers(resblock, [1, 2, 3, 4], num_classes=feat_dim)
def LightCNN_29Layers_v2(drop_ratio, out_h, out_w, feat_dim):
    """Build the 29-layer v2 Light CNN with residual block counts [1, 2, 3, 4].

    All four arguments are forwarded unchanged to ``network_29layers_v2``.
    """
    blocks_per_stage = [1, 2, 3, 4]
    return network_29layers_v2(
        resblock,
        blocks_per_stage,
        drop_ratio,
        out_h,
        out_w,
        feat_dim,
    )
def LightCNN(depth, drop_ratio, out_h, out_w, feat_dim):
    """Build a Light CNN backbone of the requested depth.

    Args:
        depth: ``9`` selects ``LightCNN_9Layers``; ``29`` selects
            ``LightCNN_29Layers_v2``.
        drop_ratio, out_h, out_w, feat_dim: forwarded unchanged to the
            selected builder.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``depth`` is neither 9 nor 29 (previously the function
            fell through and silently returned ``None``).
    """
    if depth == 9:
        return LightCNN_9Layers(drop_ratio, out_h, out_w, feat_dim)
    if depth == 29:
        return LightCNN_29Layers_v2(drop_ratio, out_h, out_w, feat_dim)
    raise ValueError(
        "Unsupported LightCNN depth: %r (expected 9 or 29)" % (depth,)
    )