-
Notifications
You must be signed in to change notification settings - Fork 140
/
main.py
67 lines (56 loc) · 2.67 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""Main script for ADDA."""
import params
from core import eval_src, eval_tgt, train_src, train_tgt
from models import Discriminator, LeNetClassifier, LeNetEncoder
from utils import get_data_loader, init_model, init_random_seed
if __name__ == '__main__':
# init random seed
init_random_seed(params.manual_seed)
# load dataset
src_data_loader = get_data_loader(params.src_dataset)
src_data_loader_eval = get_data_loader(params.src_dataset, train=False)
tgt_data_loader = get_data_loader(params.tgt_dataset)
tgt_data_loader_eval = get_data_loader(params.tgt_dataset, train=False)
# load models
src_encoder = init_model(net=LeNetEncoder(),
restore=params.src_encoder_restore)
src_classifier = init_model(net=LeNetClassifier(),
restore=params.src_classifier_restore)
tgt_encoder = init_model(net=LeNetEncoder(),
restore=params.tgt_encoder_restore)
critic = init_model(Discriminator(input_dims=params.d_input_dims,
hidden_dims=params.d_hidden_dims,
output_dims=params.d_output_dims),
restore=params.d_model_restore)
# train source model
print("=== Training classifier for source domain ===")
print(">>> Source Encoder <<<")
print(src_encoder)
print(">>> Source Classifier <<<")
print(src_classifier)
if not (src_encoder.restored and src_classifier.restored and
params.src_model_trained):
src_encoder, src_classifier = train_src(
src_encoder, src_classifier, src_data_loader)
# eval source model
print("=== Evaluating classifier for source domain ===")
eval_src(src_encoder, src_classifier, src_data_loader_eval)
# train target encoder by GAN
print("=== Training encoder for target domain ===")
print(">>> Target Encoder <<<")
print(tgt_encoder)
print(">>> Critic <<<")
print(critic)
# init weights of target encoder with those of source encoder
if not tgt_encoder.restored:
tgt_encoder.load_state_dict(src_encoder.state_dict())
if not (tgt_encoder.restored and critic.restored and
params.tgt_model_trained):
tgt_encoder = train_tgt(src_encoder, tgt_encoder, critic,
src_data_loader, tgt_data_loader)
# eval target encoder on test set of target dataset
print("=== Evaluating classifier for encoded target domain ===")
print(">>> source only <<<")
eval_tgt(src_encoder, src_classifier, tgt_data_loader_eval)
print(">>> domain adaption <<<")
eval_tgt(tgt_encoder, src_classifier, tgt_data_loader_eval)