import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch.utils.data import DataLoader

from utils.util import calculate_bank
class UnsupervisedLocalUpdate(object):
    """Local update for an unlabeled client: fills pseudo-label banks,
    estimates the label-transition matrix Pi, then trains on the banks."""

    def __init__(self, args, dataset, Pi, priors_corr):
        self.dataset = dataset
        self.ldr_train = DataLoader(self.dataset, batch_size=args.batch_size,
                                    shuffle=True, drop_last=True)
        self.epoch = 0
        self.iter_num = 0
        self.flag = True
        self.base_lr = 2e-4
        self.Pi = Pi                    # transition counts: bank label -> model prediction
        self.priors_corr = priors_corr  # class-prior correction term
        self.temp_bank = []             # confident samples, rebuilt every round
        self.permanent_bank = set()     # highly confident samples, kept across rounds
        self.real_Pi = list(Pi.numpy())
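
    # train() runs one local round in three phases: (1) pseudo-label inference
    # to fill the confidence banks, (2) estimation of the transition matrix Pi
    # over the banked samples, (3) local training on the temporary bank.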
    def train(self, args, net, op_dict, epoch, logging):
        net.cuda()
        net.train()
        self.optimizer = torch.optim.Adam(net.parameters(), lr=args.base_lr,
                                          betas=(0.9, 0.999), weight_decay=5e-4)
        self.optimizer.load_state_dict(op_dict)
        loss_fun = nn.CrossEntropyLoss()
        # op_dict restores the optimizer state from the previous round, including
        # its learning rate, so reset the lr to this client's base value.
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.base_lr
        self.epoch = epoch
        epoch_loss = []
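        # Phase 1: run inference over the full unlabeled pool and fill the two
        # pseudo-label banks: predictions above args.hi_lp join the permanent
        # bank (kept across rounds), predictions above args.lo_lp join the
        # temporary bank (rebuilt each round).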
        print(' Inference priors')
        Pi = torch.zeros_like(self.Pi.float()).cpu().int()
        net.eval()
        self.dataset.re_load()
        self.temp_bank = []
        with torch.no_grad():
            for i, (items, _, image_batch, label_batch) in enumerate(self.ldr_train):
                image_batch = image_batch.cuda()
                representations, logits, outputs = net(image_batch)
                max_probs, pseudo_labels = torch.max(outputs.cpu(), dim=1)
                items = list(items)
                # High-confidence predictions enter the permanent bank.
                sample_ids = torch.where(max_probs > args.hi_lp)[0]
                for item_id in sample_ids.cpu().numpy():
                    self.permanent_bank.add(items[item_id])
                # Lower-confidence predictions enter the temporary bank.
                sample_ids = torch.where(max_probs > args.lo_lp)[0]
                self.temp_bank.extend(items[item_id] for item_id in sample_ids.cpu().numpy())
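        # Phase 2: on the banked samples, count how bank labels map to the
        # model's current predictions to estimate the transition matrix Pi.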
        net.eval()
        self.dataset.update(calculate_bank(self.temp_bank, self.permanent_bank))
        with torch.no_grad():
            for i, (items, _, image_batch, label_batch) in enumerate(self.ldr_train):
                image_batch = image_batch.cuda()
                representations, logits, outputs = net(image_batch)
                max_probs, pseudo_labels = torch.max(outputs.cpu(), dim=1)
                label_batch = label_batch.int().cpu().numpy()
                for j in range(len(label_batch)):
                    Pi[label_batch[j]][pseudo_labels[j]] += 1
        Pi = Pi.float().cuda()
        Pi = F.normalize(Pi, p=1)  # L1-normalize each row into a distribution
        priors_corr = self.priors_corr.float().cuda()
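        # Phase 3: local training on the temporary bank, passing the estimated
        # Pi and the prior correction through the network's corrected output.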
        print(' Unsupervised training')
        net.train()
        self.dataset.update(self.temp_bank)
        for epoch in range(args.local_ep):
            batch_loss = []
            for i, (_, _, image_batch, label_batch) in enumerate(self.ldr_train):
                image_batch, label_batch = image_batch.cuda(), label_batch.cuda()
                representations, logits, outputs = net(image_batch, Pi=Pi, priors_corr=priors_corr)
                loss = loss_fun(outputs, label_batch.long())
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                batch_loss.append(loss.item())
                self.iter_num += 1
            self.epoch += 1
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        print(f' Local Loss: {epoch_loss}')
        net.cpu()
        net_states = net.state_dict()
        return net_states, sum(epoch_loss) / len(epoch_loss), copy.deepcopy(self.optimizer.state_dict())
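

if __name__ == '__main__':
    # Minimal smoke-test sketch of how a server loop might drive this class.
    # DummyNet, DummyUnlabeledDataset, and the SimpleNamespace below are
    # hypothetical stand-ins for the real model, client dataset, and
    # options.args_parser() values; they only mimic the interfaces used above.
    # Assumes a CUDA device (train() calls net.cuda()) and that
    # utils.util.calculate_bank accepts the (list, set) of item ids used above.
    from types import SimpleNamespace

    NUM_CLASSES = 5

    class DummyNet(nn.Module):
        # Stand-in model returning (representations, logits, probabilities),
        # the triple unpacked by train(). Pi / priors_corr are accepted but
        # ignored here, whereas the real head uses them to correct outputs.
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(16, 32)
            self.head = nn.Linear(32, NUM_CLASSES)

        def forward(self, x, Pi=None, priors_corr=None):
            rep = torch.relu(self.backbone(x))
            logits = self.head(rep)
            return rep, logits, torch.softmax(logits, dim=1)

    class DummyUnlabeledDataset(torch.utils.data.Dataset):
        # Stand-in dataset yielding (item_id, _, image, bank_label) tuples.
        def __init__(self, n=64):
            self.items = [f'sample_{k}' for k in range(n)]
            self.x = torch.randn(n, 16)
            self.y = torch.randint(0, NUM_CLASSES, (n,))

        def re_load(self):
            pass  # the real dataset re-reads the full unlabeled pool here

        def update(self, bank_items):
            pass  # the real dataset restricts itself to the banked items here

        def __len__(self):
            return len(self.items)

        def __getitem__(self, k):
            return self.items[k], 0, self.x[k], self.y[k]

    args = SimpleNamespace(batch_size=8, base_lr=2e-4, hi_lp=0.9, lo_lp=0.5,
                           local_ep=1)
    net = DummyNet()
    Pi = torch.zeros(NUM_CLASSES, NUM_CLASSES)
    priors_corr = torch.full((NUM_CLASSES,), 1.0 / NUM_CLASSES)

    updater = UnsupervisedLocalUpdate(args, DummyUnlabeledDataset(), Pi, priors_corr)
    op_dict = torch.optim.Adam(net.parameters(), lr=args.base_lr).state_dict()
    net_states, avg_loss, op_dict = updater.train(args, net, op_dict, epoch=0,
                                                  logging=None)
    print('mean local loss:', avg_loss)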