'''
Paper: Long, M., Cao, Z., Wang, J. and Jordan, M.I., 2018. Conditional adversarial
domain adaptation. Advances in Neural Information Processing Systems, 31.
Reference code: https://github.com/thuml/Transfer-Learning-Library
'''
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

import modules
import utils
from train_utils import TrainerBase
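
# CDAN conditions the domain discriminator on both the feature f and the
# classifier prediction g via a multilinear map T(f, g) = f ⊗ g (outer
# product), or a randomized low-dimensional approximation when the product
# of feature and class dimensions is large. With entropy conditioning
# (CDAN+E), examples are re-weighted by w = 1 + e^{-H(g)} so that confident
# predictions dominate the adversarial alignment.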


class RandomizedMultiLinearMap(nn.Module):
    """Randomized multilinear map from CDAN: T(f, g) = (R_f f) * (R_g g) / sqrt(d).

    R_f and R_g are fixed random matrices (sampled once, never trained) that
    approximate the full outer product f ⊗ g in a d-dimensional space.
    """
    def __init__(self, features_dim: int, num_classes: int, output_dim: int = 1024):
        super(RandomizedMultiLinearMap, self).__init__()
        self.Rf = torch.randn(features_dim, output_dim)
        self.Rg = torch.randn(num_classes, output_dim)
        self.output_dim = output_dim

    def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        f = torch.mm(f, self.Rf.to(f.device))
        g = torch.mm(g, self.Rg.to(g.device))
        # Element-wise product of the two projections, scaled to keep variance stable.
        output = torch.mul(f, g) / np.sqrt(float(self.output_dim))
        return output
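
# Usage sketch (illustrative shapes only, not part of the training pipeline):
#   m = RandomizedMultiLinearMap(features_dim=256, num_classes=10, output_dim=1024)
#   h = m(torch.randn(4, 256), torch.randn(4, 10))   # h.shape == (4, 1024)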


class MultiLinearMap(nn.Module):
    """Exact multilinear map: flattens the outer product g ⊗ f.

    The output dimension is num_classes * features_dim, so this is practical
    only when that product is small; otherwise use RandomizedMultiLinearMap.
    """
    def __init__(self):
        super(MultiLinearMap, self).__init__()

    def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        batch_size = f.size(0)
        # (B, C, 1) x (B, 1, F) -> (B, C, F), flattened to (B, C * F).
        output = torch.bmm(g.unsqueeze(2), f.unsqueeze(1))
        return output.view(batch_size, -1)
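
# Usage sketch: the exact map returns num_classes * features_dim values per example.
#   m = MultiLinearMap()
#   h = m(torch.randn(4, 256), torch.randn(4, 10))   # h.shape == (4, 2560)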


class ConditionalDomainAdversarialLoss(nn.Module):
    """Conditional adversarial loss from CDAN.

    The discriminator sees the multilinear conditioning map(f, g) passed
    through a gradient reversal layer; with entropy_conditioning=True the
    per-example BCE is re-weighted by w = 1 + e^{-H(g)} (CDAN+E).
    """
    def __init__(self, domain_discriminator: nn.Module, entropy_conditioning: bool = False,
                 randomized: bool = False, num_classes: int = -1,
                 features_dim: int = -1, randomized_dim: int = 1024,
                 reduction: str = 'mean', sigmoid: bool = True, grl: nn.Module = None):
        super(ConditionalDomainAdversarialLoss, self).__init__()
        self.domain_discriminator = domain_discriminator
        self.grl = utils.WarmStartGradientReverseLayer(alpha=1., lo=0., hi=1., max_iters=1000, auto_step=True) \
            if grl is None else grl
        self.entropy_conditioning = entropy_conditioning
        self.sigmoid = sigmoid
        self.reduction = reduction
        if randomized:
            assert num_classes > 0 and features_dim > 0 and randomized_dim > 0
            self.map = RandomizedMultiLinearMap(features_dim, num_classes, randomized_dim)
        else:
            self.map = MultiLinearMap()
        # BCE that applies the entropy weight only when entropy conditioning is on.
        self.bce = lambda input, target, weight: F.binary_cross_entropy(
            input, target, weight, reduction=reduction) if self.entropy_conditioning \
            else F.binary_cross_entropy(input, target, reduction=reduction)
        self.domain_discriminator_accuracy = None

    def forward(self, g_s: torch.Tensor, f_s: torch.Tensor, g_t: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
        f = torch.cat((f_s, f_t), dim=0)
        g = torch.cat((g_s, g_t), dim=0)
        g = F.softmax(g, dim=1).detach()    # condition on predictions, no gradient through g
        h = self.grl(self.map(f, g))        # reverse gradients into the feature extractor
        d = self.domain_discriminator(h)
        # Entropy weight w = 1 + e^{-H(g)}, normalized to sum to the batch size.
        weight = 1.0 + torch.exp(-utils.entropy(g))
        batch_size = f.size(0)
        weight = weight / torch.sum(weight) * batch_size
        if self.sigmoid:
            # Discriminator outputs one probability per example; source = 1, target = 0.
            d_label = torch.cat((
                torch.ones((g_s.size(0), 1)).to(g_s.device),
                torch.zeros((g_t.size(0), 1)).to(g_t.device),
            ))
            self.domain_discriminator_accuracy = utils.binary_accuracy(d, d_label)
            return self.bce(d, d_label, weight.view_as(d))
        else:
            # Discriminator outputs two logits per example; use cross entropy.
            d_label = torch.cat((
                torch.ones((g_s.size(0), )).to(g_s.device),
                torch.zeros((g_t.size(0), )).to(g_t.device),
            )).long()
            self.domain_discriminator_accuracy = utils.get_accuracy(d, d_label)
            if self.entropy_conditioning:
                raise NotImplementedError("entropy_conditioning")
            return F.cross_entropy(d, d_label, reduction=self.reduction)
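
# Usage sketch (a toy stand-in discriminator, not the project's modules.MLP):
#   discri = nn.Sequential(nn.Linear(256 * 10, 64), nn.ReLU(),
#                          nn.Linear(64, 1), nn.Sigmoid())
#   adv = ConditionalDomainAdversarialLoss(discri)
#   # g_*: (B, 10) classifier logits, f_*: (B, 256) backbone features
#   loss_d = adv(g_s, f_s, g_t, f_t)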


class Trainer(TrainerBase):
    def __init__(self, args):
        super(Trainer, self).__init__(args)
        self.model = modules.ClassifierBase(input_size=1, num_classes=args.num_classes[0],
                                            backbone=args.backbone, dropout=args.dropout).to(self.device)
        # The discriminator input is the flattened outer product: feature_dim * num_classes.
        self.domain_discri = modules.MLP(input_size=self.model.feature_dim * args.num_classes[0], output_size=1,
                                         dropout=args.dropout, last='sigmoid').to(self.device)
        grl = utils.GradientReverseLayer()
        self.domain_adv = ConditionalDomainAdversarialLoss(self.domain_discri, grl=grl)
        self._init_data()

        if args.train_mode == 'single_source':
            self.src = args.source_name[0]
        elif args.train_mode == 'source_combine':
            self.src = 'concat_source'
        elif args.train_mode == 'multi_source':
            raise Exception("This model cannot be trained in multi_source mode.")
        else:
            raise Exception("Unknown train_mode {}".format(args.train_mode))

        self.optimizer = self._get_optimizer([self.model, self.domain_discri])
        self.lr_scheduler = self._get_lr_scheduler(self.optimizer)
        self.num_iter = len(self.dataloaders[self.src])

    def save_model(self):
        torch.save({
            'model': self.model.state_dict()
        }, self.args.save_path + '.pth')
        logging.info('Model saved to {}'.format(self.args.save_path + '.pth'))

    def load_model(self):
        logging.info('Loading model from {}'.format(self.args.load_path))
        ckpt = torch.load(self.args.load_path)
        self.model.load_state_dict(ckpt['model'])

    def _set_to_train(self):
        self.model.train()
        self.domain_discri.train()

    def _set_to_eval(self):
        self.model.eval()

    def _train_one_epoch(self, epoch_acc, epoch_loss):
        for _ in tqdm(range(self.num_iter), ascii=True):
            # Obtain an unlabeled target batch and a labeled source batch.
            target_data, _ = self._get_next_batch('train')
            source_data, source_labels = self._get_next_batch(self.src)

            # Forward both domains in one pass, then split features and logits.
            self.optimizer.zero_grad()
            data = torch.cat((source_data, target_data), dim=0)
            y, f = self.model(data)
            f_s, f_t = f.chunk(2, dim=0)
            y_s, y_t = y.chunk(2, dim=0)

            # Source classification loss plus the conditional adversarial loss.
            loss_c = F.cross_entropy(y_s, source_labels)
            loss_d = self.domain_adv(y_s, f_s, y_t, f_t)
            loss = loss_c + self.tradeoff[0] * loss_d

            # Log information; take .item() so computation graphs are not retained.
            epoch_acc['Source Data'] += self._get_accuracy(y_s, source_labels)
            epoch_acc['Discriminator'] += self.domain_adv.domain_discriminator_accuracy
            epoch_loss['Source Classifier'] += loss_c.item()
            epoch_loss['Discriminator'] += loss_d.item()

            # Backward.
            loss.backward()
            self.optimizer.step()
        return epoch_acc, epoch_loss

    def _eval(self, data, actual_labels, correct, total):
        pred = self.model(data)
        actual_pred = self._get_actual_label(pred, idx=0)
        output = self._get_accuracy(actual_pred, actual_labels, return_acc=False)
        correct['acc'] += output[0]
        total['acc'] += output[1]
        if self.args.da_scenario in ['open-set', 'universal']:
            output = self._get_accuracy(actual_pred, actual_labels, return_acc=False, idx=0, mode='closed-set')
            correct['Closed-set-acc'] += output[0]
            total['Closed-set-acc'] += output[1]
        return correct, total
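
# Typical driver (sketch; assumes argparse-style args with the fields used
# above and a train() entry point provided by TrainerBase, both defined
# elsewhere in this repository):
#   trainer = Trainer(args)
#   trainer.train()        # calls _train_one_epoch / _eval internally
#   trainer.save_model()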