-
Notifications
You must be signed in to change notification settings - Fork 98
/
quantization_aware_training.py
119 lines (88 loc) · 3.66 KB
/
quantization_aware_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from model import *
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
def quantize_aware_training(model, device, train_loader, optimizer, epoch):
    """Run one epoch of quantization-aware training over ``train_loader``.

    Every batch is moved to ``device``, pushed through
    ``model.quantize_forward``, scored with cross-entropy and followed by a
    single optimizer step.  A progress line is printed on every 50th batch;
    ``epoch`` is used only for that message.
    """
    criterion = torch.nn.CrossEntropyLoss()
    progress_fmt = 'Quantize Aware Training Epoch: {} [{}/{}]\tLoss: {:.6f}'

    step = 0
    for inputs, labels in train_loader:
        step += 1  # batches are counted starting at 1
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model.quantize_forward(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        if step % 50:
            continue

        # The sample count is an estimate: it multiplies the batch index by
        # the size of the *current* batch.
        samples_seen = step * len(inputs)
        dataset_size = len(train_loader.dataset)
        print(progress_fmt.format(epoch, samples_seen, dataset_size, batch_loss.item()))
def full_inference(model, test_loader, device=None):
    """Evaluate the regular (non-quantized) forward pass and print accuracy.

    Args:
        model: Module called as ``model(data)``; the prediction is the
            argmax of its output along dim 1.
        test_loader: Iterable of ``(data, target)`` batches that exposes a
            ``dataset`` attribute whose length is the number of samples.
        device: Device each batch is moved to.  When omitted, the device of
            the model's first parameter is used (CPU if the model has no
            parameters), so the function no longer depends on a
            module-level ``device`` variable being defined.

    Returns:
        The top-1 accuracy in percent, as a float.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    # Evaluation needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Full Model Accuracy: {:.0f}%\n'.format(accuracy))
    return accuracy
def quantize_inference(model, test_loader, device=None):
    """Evaluate ``model.quantize_inference`` on a loader and print accuracy.

    Args:
        model: Module providing a ``quantize_inference(data)`` method; the
            prediction is the argmax of its output along dim 1.
        test_loader: Iterable of ``(data, target)`` batches that exposes a
            ``dataset`` attribute whose length is the number of samples.
        device: Device each batch is moved to.  When omitted, the device of
            the model's first parameter is used (CPU if the model has no
            parameters), so the function no longer depends on a
            module-level ``device`` variable being defined.

    Returns:
        The top-1 accuracy in percent, as a float.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    # Evaluation needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model.quantize_inference(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Quant Model Accuracy: {:.0f}%\n'.format(accuracy))
    return accuracy
if __name__ == "__main__":
    # Run configuration.
    batch_size = 64
    seed = 1
    epochs = 3
    lr = 0.01
    momentum = 0.5
    using_bn = True  # selects NetBN (True) or Net (False) and the matching checkpoints
    # Set to a checkpoint path to start from an already quantized model.
    load_quant_model_file = None
    # load_quant_model_file = "ckpt/mnist_cnnbn_qat.pt"

    torch.manual_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Train and test data share the same preprocessing.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=mnist_transform),
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=mnist_transform),
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
    )

    # Start from the pretrained full-precision weights.
    if using_bn:
        model = NetBN()
        model.load_state_dict(torch.load('ckpt/mnist_cnnbn.pt', map_location='cpu'))
        save_file = "ckpt/mnist_cnnbn_qat.pt"
    else:
        model = Net()
        model.load_state_dict(torch.load('ckpt/mnist_cnn.pt', map_location='cpu'))
        save_file = "ckpt/mnist_cnn_qat.pt"
    model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    # Report the accuracy of the pretrained model before quantization.
    model.eval()
    full_inference(model, test_loader)

    num_bits = 8
    # quantize() is defined in model.py; presumably it sets up the quantized
    # layers used by quantize_forward -- confirm there.
    model.quantize(num_bits=num_bits)
    print('Quantization bit: %d' % num_bits)

    if load_quant_model_file is not None:
        # map_location keeps the load working when the checkpoint was saved
        # on a device that is unavailable here (e.g. CUDA file, CPU-only host).
        model.load_state_dict(torch.load(load_quant_model_file, map_location=device))
        print("Successfully load quantized model %s" % load_quant_model_file)

    model.train()

    for epoch in range(1, epochs + 1):
        quantize_aware_training(model, device, train_loader, optimizer, epoch)

    model.eval()

    # The checkpoint is written before freeze() is called.
    torch.save(model.state_dict(), save_file)

    # freeze() is defined in model.py; presumably it fixes the quantization
    # parameters needed by quantize_inference -- confirm there.
    model.freeze()

    quantize_inference(model, test_loader)