-
Notifications
You must be signed in to change notification settings - Fork 1
/
ptta.py
115 lines (91 loc) · 3.26 KB
/
ptta.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import logging
import torch
import argparse
from core.configs import cfg
from core.utils import *
from core.model import build_model
from core.data import build_loader
from core.optim import build_optimizer
from core.adapter import build_adapter
from tqdm import tqdm
from setproctitle import setproctitle
def testTimeAdaptation(cfg):
    """Run test-time adaptation over the configured corruption stream.

    Builds the model, optimizer and adapter from ``cfg`` and wraps them into a
    single adapted model that is moved to CUDA. Every batch from the corruption
    loader is then fed through it. The per-sample hit mask is passed, together
    with the batch's domain info, to the loader's result processor. Finally the
    processor's summary is logged.

    Args:
        cfg: Configuration node. ``cfg.CORRUPTION.DATASET``,
            ``cfg.CORRUPTION.TYPE`` and ``cfg.CORRUPTION.SEVERITY`` select the
            evaluation stream handed to ``build_loader``.
    """
    log = logging.getLogger("TTA.test_time")

    # Build order is kept as model -> optimizer -> adapter. Builders may
    # consume RNG state, so reordering them could change results.
    net = build_model(cfg)
    # NOTE(review): build_optimizer only receives cfg, so presumably the adapter
    # binds it to the model's parameters itself — confirm in core.optim.
    optimizer = build_optimizer(cfg)
    adapter_cls = build_adapter(cfg)
    tta_model = adapter_cls(cfg, net, optimizer)
    tta_model.cuda()

    loader, processor = build_loader(
        cfg,
        cfg.CORRUPTION.DATASET,
        cfg.CORRUPTION.TYPE,
        cfg.CORRUPTION.SEVERITY,
    )
    progress = tqdm(loader)
    for step, batch in enumerate(progress):
        images = batch["image"]
        labels = batch['label']
        domains = batch['domain']

        # A batch holding a single sample (typically the trailing one) is
        # skipped entirely: it is neither adapted on nor scored.
        if len(labels) == 1:
            continue

        images = images.cuda()
        labels = labels.cuda()
        logits = tta_model(images)  # forward pass; the adapter may update the model here
        hits = torch.argmax(logits, dim=1) == labels
        processor.process(hits, domains)

        # Refresh the progress-bar statistics every 10 batches.
        if step % 10 == 0:
            tracks_bank = hasattr(tta_model, "mem")
            stats = {"acc": processor.cumulative_acc()}
            if tracks_bank:
                # Adapters that keep a memory bank also report its occupancy.
                stats["bank"] = tta_model.mem.get_occupancy()
            progress.set_postfix(**stats)

    processor.calculate()
    log.info(f"All Results\n{processor.info()}")
def main():
parser = argparse.ArgumentParser("Pytorch Implementation for Test Time Adaptation!")
parser.add_argument(
'-acfg',
'--adapter-config-file',
metavar="FILE",
default="",
help="path to adapter config file",
type=str)
parser.add_argument(
'-dcfg',
'--dataset-config-file',
metavar="FILE",
default="",
help="path to dataset config file",
type=str)
parser.add_argument(
'-ocfg',
'--order-config-file',
metavar="FILE",
default="",
help="path to order config file",
type=str)
parser.add_argument(
'opts',
help='modify the configuration by command line',
nargs=argparse.REMAINDER,
default=None)
args = parser.parse_args()
if len(args.opts) > 0:
args.opts[-1] = args.opts[-1].strip('\r\n')
torch.backends.cudnn.benchmark = True
cfg.merge_from_file(args.adapter_config_file)
cfg.merge_from_file(args.dataset_config_file)
if not args.order_config_file == "":
cfg.merge_from_file(args.order_config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
ds = cfg.CORRUPTION.DATASET
adapter = cfg.ADAPTER.NAME
setproctitle(f"TTA:{ds:>8s}:{adapter:<10s}")
if cfg.OUTPUT_DIR:
mkdir(cfg.OUTPUT_DIR)
logger = setup_logger('TTA', cfg.OUTPUT_DIR, 0, filename=cfg.LOG_DEST)
logger.info(args)
logger.info(f"Loaded configuration file: \n"
f"\tadapter: {args.adapter_config_file}\n"
f"\tdataset: {args.dataset_config_file}\n"
f"\torder: {args.order_config_file}")
logger.info("Running with config:\n{}".format(cfg))
set_random_seed(cfg.SEED)
testTimeAdaptation(cfg)
if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when it is imported.
    main()