# bn_remove.py
# Build a copy of a trained CIFAR-10 network without its BatchNorm layers and
# compare its outputs against the original model.
import torch.nn as nn
from torch.utils.data import DataLoader
import torch
import torchvision
import torchvision.transforms as transforms
# Instantiate model with BN and load trained parameters
class smallNetTrain(nn.Module):
    """Small CIFAR-10 network with BatchNorm after conv1, conv2 and fc1.

    Every stage is an ``nn.Sequential`` whose element 0 is the conv/linear
    layer and, for conv1, conv2 and fc1, whose element 1 is the BatchNorm
    layer. The attribute names and positions determine the state-dict keys
    (``conv1.0.weight``, ``conv1.1.running_mean``, ...), so they must not be
    reordered if a saved checkpoint is to keep loading.
    """

    # CIFAR-10 data is 32*32 images with 3 RGB channels
    def __init__(self, input_dim=3*32*32):
        """Build the layers.

        Args:
            input_dim: Accepted for interface compatibility; it is not used,
                the layer sizes below are hard-coded.
        """
        super().__init__()
        # 3x3 convolutions with padding=1 keep the 32x32 spatial size.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )
        self.fc1 = nn.Sequential(
            nn.Linear(16*32*32, 32*32),
            nn.BatchNorm1d(32*32),
            nn.ReLU()
        )
        # The last stage has no BatchNorm; its 10 outputs pass through ReLU.
        self.fc2 = nn.Sequential(
            nn.Linear(32*32, 10),
            nn.ReLU()
        )

    def forward(self, x):
        """Run the network on ``x`` and return a tensor with 10 columns."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten the 16 feature maps of 32x32 into one row per sample.
        x = x.float().view(-1, 16*32*32)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
# Build the BN-equipped network, then restore its trained parameters on the CPU.
model = smallNetTrain()
_trained_state = torch.load("./smallNetSaved", map_location=torch.device('cpu'))
model.load_state_dict(_trained_state)
# Instantiate model without BN
class smallNetTest(nn.Module):
    """Small CIFAR-10 network with the same layout as the BN model, minus BN.

    Every stage is an ``nn.Sequential`` whose element 0 is the conv/linear
    layer, followed directly by ReLU. There are no BatchNorm layers, so the
    only parameters are the weights and biases of the four element-0 layers.
    """

    # CIFAR-10 data is 32*32 images with 3 RGB channels
    def __init__(self, input_dim=3*32*32):
        """Build the layers.

        Args:
            input_dim: Accepted for interface compatibility; it is not used,
                the layer sizes below are hard-coded.
        """
        super().__init__()
        # 3x3 convolutions with padding=1 keep the 32x32 spatial size.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.fc1 = nn.Sequential(
            nn.Linear(16*32*32, 32*32),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(32*32, 10),
            nn.ReLU()
        )

    def forward(self, x):
        """Run the network on ``x`` and return a tensor with 10 columns."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten the 16 feature maps of 32x32 into one row per sample.
        x = x.float().view(-1, 16*32*32)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
model_test = smallNetTest()

# Initialize weights of model without BN
# Learned affine parameters (beta = bias, gamma = weight) and running
# statistics of each BatchNorm layer in the trained model.
conv1_bn_beta, conv1_bn_gamma = model.conv1[1].bias, model.conv1[1].weight
conv1_bn_mean, conv1_bn_var = model.conv1[1].running_mean, model.conv1[1].running_var
conv2_bn_beta, conv2_bn_gamma = model.conv2[1].bias, model.conv2[1].weight
conv2_bn_mean, conv2_bn_var = model.conv2[1].running_mean, model.conv2[1].running_var
fc1_bn_beta, fc1_bn_gamma = model.fc1[1].bias, model.fc1[1].weight
fc1_bn_mean, fc1_bn_var = model.fc1[1].running_mean, model.fc1[1].running_var
# The BatchNorm layers are built without an eps argument, so this is meant to
# equal the PyTorch default they use.
eps = 1e-05


def _fold_batchnorm(weight, bias, gamma, beta, mean, var, eps):
    """Fold an inference-mode BatchNorm into the preceding conv/linear layer.

    In eval mode BatchNorm computes ``gamma * (z - mean) / sqrt(var + eps) + beta``
    for ``z = W x + b``. That equals ``(scale * W) x + (scale * (b - mean) + beta)``
    with ``scale = gamma / sqrt(var + eps)`` applied per output channel.

    Args:
        weight: Weight of the conv/linear layer; dim 0 is the output channel.
        bias: Bias of the conv/linear layer, one entry per output channel.
        gamma: BatchNorm weight.
        beta: BatchNorm bias.
        mean: BatchNorm running mean.
        var: BatchNorm running variance.
        eps: Constant added to the variance inside the square root.

    Returns:
        Tuple ``(folded_weight, folded_bias)`` with the same shapes as
        ``weight`` and ``bias``.
    """
    scale = gamma / torch.sqrt(var + eps)
    # Broadcast the per-output-channel scale over the remaining weight dims
    # (3 trailing dims for Conv2d, 1 for Linear).
    broadcast_shape = [-1] + [1] * (weight.dim() - 1)
    folded_weight = weight * scale.reshape(broadcast_shape)
    folded_bias = (bias - mean) * scale + beta
    return folded_weight, folded_bias


# no_grad: the parameters are overwritten in place and no autograd graph
# should be recorded for the folding arithmetic.
with torch.no_grad():
    for _target, _source, _bn in (
        (model_test.conv1[0], model.conv1[0],
         (conv1_bn_gamma, conv1_bn_beta, conv1_bn_mean, conv1_bn_var)),
        (model_test.conv2[0], model.conv2[0],
         (conv2_bn_gamma, conv2_bn_beta, conv2_bn_mean, conv2_bn_var)),
        (model_test.fc1[0], model.fc1[0],
         (fc1_bn_gamma, fc1_bn_beta, fc1_bn_mean, fc1_bn_var)),
    ):
        _w, _b = _fold_batchnorm(_source.weight, _source.bias, *_bn, eps)
        _target.weight.copy_(_w)
        _target.bias.copy_(_b)

    # fc2 is not followed by BatchNorm, so its parameters carry over unchanged.
    model_test.fc2[0].weight.copy_(model.fc2[0].weight)
    model_test.fc2[0].bias.copy_(model.fc2[0].bias)
# Verify difference between model and model_test
# eval() makes the BatchNorm layers use their running statistics instead of
# per-batch statistics.
model.eval()
# model_test.eval() # not necessary since model_test has no BN or dropout

# CIFAR-10 test split, converted to tensors; downloaded on first use.
test_dataset = torchvision.datasets.CIFAR10(root='./cifar_10data/',
                                            train=False,
                                            transform=transforms.ToTensor(), download = True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)

diff = []
with torch.no_grad():
    for images, _ in test_loader:
        # Squared norm of the gap between the two models' outputs on this batch.
        diff.append(torch.norm(model(images) - model_test(images))**2)

print(max(diff)) # If less than 1e-08, you got the right answer.
'''
For debugging purposes, you may want to match the output of conv1 first,
before moving on to conv2. To do so, replace the forward-evaluation
functions of both models with

    def forward(self, x):
        x = self.conv1(x)
        return x
'''