-
Notifications
You must be signed in to change notification settings - Fork 2
/
alexnet.py
61 lines (55 loc) · 2.58 KB
/
alexnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from torch import nn
class AlexNet(nn.Module):
    """AlexNet-style convolutional image classifier built as one ``nn.Sequential``.

    The stack is five convolutional layers (the first two followed by local
    response normalization and max pooling, the last followed by max pooling),
    then a flatten step and three fully connected layers with dropout in front
    of the first two.

    The first fully connected layer expects ``256 * 6 * 6`` input features, so
    the spatial size of the input must reduce to 6x6 after the last pooling
    layer; both 224x224 and 227x227 inputs do.

    Args:
        num_classes: Number of output logits produced by the last layer.
    """

    def __init__(self, num_classes: int = 1000) -> None:
        super().__init__()
        self.net = nn.Sequential(
            # padding=2 makes a 224x224 input (the size stated in the paper)
            # produce the 55x55 feature map expected by the next layer.
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(),
            # alpha is 1e-4 as in the paper; the previous literal `10e-4`
            # evaluates to 1e-3, i.e. ten times too large.
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Flatten(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=256 * 6 * 6, out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=num_classes),
        )
        self.init_weights()

    def init_weights(self) -> None:
        """Initialize the weights and biases of every conv and linear layer in place.

        Conv weights use Kaiming-normal initialization (fan_out, ReLU) and
        linear weights are drawn from N(0, 0.01). Biases are 0 for conv layers
        1 and 3, and 1 for conv layers 2, 4, 5 and for every linear layer.
        """
        conv_layers = []
        for layer in self.net:
            # First pass: uniform treatment per layer type.
            if isinstance(layer, nn.Conv2d):
                # The paper's own strategy would be normal_(mean=0, std=0.01);
                # Kaiming-normal is used here instead.
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(layer.bias, 0)
                conv_layers.append(layer)
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0, std=0.01)
                # NOTE(review): this also sets the output layer's bias to 1;
                # the paper describes bias 1 for the hidden fully connected
                # layers only -- confirm whether that is intended.
                nn.init.constant_(layer.bias, 1)
        # Second pass: the 2nd, 4th and 5th conv layers get bias 1. They are
        # picked by position among the conv layers rather than by hard-coded
        # Sequential indices (4, 10, 12), which selects the same modules.
        for position in (1, 3, 4):
            nn.init.constant_(conv_layers[position].bias, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the whole layer stack and return the class logits."""
        return self.net(x)

    def test_output_shape(self, image_size: int = 227) -> None:
        """Print the output shape of every layer for one random RGB image.

        Args:
            image_size: Height and width of the square probe image. The
                default of 227 yields a 56x56 map after the first conv layer;
                224 yields 55x55.
        """
        # Create the probe on the same device as the parameters so the helper
        # also works after the model has been moved off the CPU.
        device = next(self.parameters()).device
        test_img = torch.rand(
            size=(1, 3, image_size, image_size), dtype=torch.float32, device=device
        )
        # Only shapes are inspected, so skip building the autograd graph.
        with torch.no_grad():
            for layer in self.net:
                test_img = layer(test_img)
                print(layer.__class__.__name__, 'output shape: \t', test_img.shape)
# alexnet = AlexNet()
# alexnet.test_output_shape()