-
Notifications
You must be signed in to change notification settings - Fork 0
/
encrytedNN.py
156 lines (121 loc) · 4.97 KB
/
encrytedNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Number of full passes over the (secret-shared) training data.
epochs = 1
# We don't use the whole dataset for efficiency purpose, but feel free to increase these numbers
# n_train_items = 320
# n_test_items = 320
# NOTE(review): with batch_size 128 this caps the loaders at 784 batches each —
# more than the MNIST test set actually contains; confirm the intended sizes.
n_train_items = 784 * 128
n_test_items = 784 * 128
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import time
class Arguments:
    """Hyper-parameter container for the encrypted-training script."""

    def __init__(self):
        self.batch_size = 128          # training batch size
        self.test_batch_size = 128     # evaluation batch size
        self.epochs = epochs           # taken from the module-level `epochs`
        self.lr = 0.02                 # SGD learning rate
        self.seed = 1                  # torch RNG seed
        self.log_interval = 1          # log info at each batch
        self.precision_fractional = 3  # fixed-point fractional digits for encryption
        self.normalize_mean = 0.1307   # MNIST per-pixel mean
        self.normalize_std = 0.3081    # MNIST per-pixel std
        self.feature_num = 784         # flattened 28x28 input size
args = Arguments()
# Seed torch's RNG for reproducibility; the returned generator is discarded.
_ = torch.manual_seed(args.seed)
import syft as sy  # import the Pysyft library
import participants as pt  # project-local helper that wires up the syft workers
# Two data-owning parties plus a crypto provider supplying correlated randomness.
owners = ['alice', 'bob']
crypto_provider = 'crypto_provider'
parties = pt.Parties()
parties.init_parties(owners, crypto_provider)
def get_private_data_loaders(workers, crypto_provider):
    """Build secret-shared MNIST train/test loaders.

    Each batch of images is secret-shared across `workers` (with
    `crypto_provider` supplying correlated randomness) and then normalized
    in the shared domain. Train targets are one-hot encoded before sharing;
    test targets are shared as floats.

    Args:
        workers: syft workers that will hold the shares.
        crypto_provider: syft worker acting as the crypto provider.

    Returns:
        (private_train_loader, private_test_loader): lists of
        (shared_data, shared_target) batch tuples.

    Note: also reads the module-level `args`, `n_train_items`, `n_test_items`.
    """
    # Project-local modules, imported lazily so configuration above runs first.
    import securefe as sfe
    import dataloader as dld

    # `time` is already imported at module level; the redundant local
    # `import time` was removed.
    start_time = time.time()

    train_loader = dld.Dataloader()
    train_loader.load_dataset("MNIST", isTrain=True, batch_size=args.batch_size)
    args.feature_num = 784  # flattened 28x28 MNIST images

    test_loader = dld.Dataloader()
    test_loader.load_dataset("MNIST", isTrain=False, batch_size=args.test_batch_size)

    mean = args.normalize_mean
    std = args.normalize_std

    # Share + securely normalize each batch; the `if i < ...` guard caps the
    # number of batches so only ~n_train_items / n_test_items samples are used.
    private_train_loader = [
        (sfe.secure_normalize(sfe.secret_share(data, workers, crypto_provider), mean, std),
         sfe.secret_share(sfe.one_hot_of(target), workers, crypto_provider))
        for i, (data, target) in enumerate(train_loader.loader)
        if i < n_train_items / args.batch_size
    ]
    private_test_loader = [
        (sfe.secure_normalize(sfe.secret_share(data, workers, crypto_provider), mean, std),
         sfe.secret_share(target.float(), workers, crypto_provider))
        for i, (data, target) in enumerate(test_loader.loader)
        if i < n_test_items / args.test_batch_size
    ]

    print("Normalizing time: {:.2f}s".format(time.time() - start_time))
    return private_train_loader, private_test_loader
# Materialize the secret-shared loaders once, up front (this is the slow part).
private_train_loader, private_test_loader = get_private_data_loaders(
    workers=parties.data_owners,
    crypto_provider=parties.crypto_provider
)
# print(private_test_loader)
# exit()
class Net(nn.Module):
    """Three-layer fully connected MLP: 784 -> 128 -> 64 -> 10 logits."""

    def __init__(self):
        super(Net, self).__init__()
        # TODO: if the feature count is reduced, shouldn't the input layer
        # size shrink accordingly?
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # view() does not work on shared tensors here, so use
        # torch.Tensor.reshape instead.
        flat = torch.Tensor.reshape(x, (-1, args.feature_num))
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
def train(args, model, private_train_loader, optimizer, epoch):
    """Run one epoch of encrypted SGD over the secret-shared training set."""
    model.train()
    n_batches = len(private_train_loader)
    for step, (data, target) in enumerate(private_train_loader):  # private batches
        t0 = time.time()
        optimizer.zero_grad()
        preds = model(data)
        # F.nll_loss is not possible on shared tensors, so use a plain MSE.
        n = preds.shape[0]
        # refresh() re-randomizes the shares before backprop.
        loss = ((preds - target) ** 2).sum().refresh() / n
        loss.backward()
        optimizer.step()
        if step % args.log_interval == 0:
            # Decrypt the scalar loss only for logging.
            loss = loss.get().float_precision()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tTime: {:.3f}s'.format(
                epoch, step * args.batch_size, n_batches * args.batch_size,
                100. * step / n_batches, loss.item(), time.time() - t0))
def test(args, model, private_test_loader):
    """Evaluate the encrypted model and print plaintext accuracy.

    Predictions stay in the secret-shared domain; only the aggregate
    correct-count is decrypted at the end.

    Fixes vs. the original: the unused `test_loss` and per-batch
    `start_time` locals are removed, and the number of evaluated samples is
    accumulated per batch instead of assuming every batch holds exactly
    `args.test_batch_size` items — the last batch may be smaller, which
    previously understated the printed accuracy.
    """
    model.eval()
    correct = 0
    total = 0  # actual number of evaluated samples
    with torch.no_grad():
        for data, target in private_test_loader:
            output = model(data)
            pred = output.argmax(dim=1)
            # eq/sum are computed on shared tensors.
            correct += pred.eq(target.view_as(pred)).sum()
            total += pred.shape[0]
    # Decrypt only the aggregate counter.
    correct = correct.get().float_precision()
    print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct.item(), total,
        100. * correct.item() / total))
model = Net()
# Encrypt the model: fixed-precision encoding, then secret-sharing across the
# data owners using the FSS protocol; requires_grad=True keeps it trainable.
model = model.fix_precision().share(*parties.data_owners, crypto_provider=parties.crypto_provider, protocol="fss", requires_grad=True)
optimizer = optim.SGD(model.parameters(), lr=args.lr)
# The optimizer must match the fixed-precision encoding of the model.
optimizer = optimizer.fix_precision()
model.train()
for epoch in range(1, args.epochs + 1):
    train(args, model, private_train_loader, optimizer, epoch)
    test(args, model, private_test_loader)