-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_dct.py
77 lines (59 loc) · 2.57 KB
/
test_dct.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import logging
import os
import pprint
import torch
from torch import nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
import yaml
from dataset.semi_dct import SemiDatasetDCT
from model.semseg.segmentor import Segmentor
from supervised_dct import evaluate
from util.ohem import ProbOhemCrossEntropy2d
from util.utils import count_params, init_log
from util.dist_helper import setup_distributed
from util.classes import CLASSES
# Command-line interface for evaluating a trained checkpoint.
parser = argparse.ArgumentParser(description='Semi-Supervised Semantic Segmentation')
parser.add_argument('--config', type=str, required=True)  # path to the YAML config file
parser.add_argument('--ckpt', type=str, required=True)  # path to the saved model state_dict
# Accept both spellings of the local-rank flag: older torch.distributed.launch
# passes --local_rank, newer PyTorch releases pass --local-rank. Both are
# stored in args.local_rank.
parser.add_argument('--local_rank', '--local-rank', default=0, type=int)
parser.add_argument('--port', default=None, type=int)  # forwarded to setup_distributed()
def main():
    """Evaluate a trained DCT segmentation checkpoint on the validation split.

    Reads CLI arguments from the module-level ``parser``, loads the YAML config
    and the checkpoint, wraps the model in SyncBatchNorm + DistributedDataParallel
    and runs ``evaluate`` over the distributed validation loader. Rank 0 prints
    the config, the parameter count, the mean IoU and the per-class IoU
    (sorted from lowest to highest).
    """
    args = parser.parse_args()

    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # point --config at trusted files (yaml.SafeLoader is stricter but may
    # reject configs that rely on Python-specific tags — confirm before switching).
    with open(args.config, "r") as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    logger = init_log('global', logging.INFO)
    logger.propagate = False

    rank, world_size = setup_distributed(port=args.port)

    if rank == 0:
        print('{}\n'.format(pprint.pformat(cfg)))

    cudnn.enabled = True
    cudnn.benchmark = True

    model = Segmentor(cfg)
    # Weights are loaded into the bare model, before DDP wrapping.
    # NOTE(review): assumes the checkpoint is a plain state_dict whose keys
    # match Segmentor (e.g. no 'module.' prefix) — confirm against the saving code.
    model.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    if rank == 0:
        print('Total params: {:.1f}M\n'.format(count_params(model)))

    # Prefer LOCAL_RANK from the launcher environment; fall back to the
    # --local_rank CLI flag for launchers that only pass the argument.
    local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank,
        find_unused_parameters=False)

    valset = SemiDatasetDCT(cfg['dataset'], cfg['data_root'], 'val')
    valsampler = torch.utils.data.distributed.DistributedSampler(valset)
    valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=2,
                           drop_last=False, sampler=valsampler)

    # Cityscapes is evaluated with a sliding window; other datasets use the
    # original image as-is.
    eval_mode = 'sliding_window' if cfg['dataset'] == 'cityscapes' else 'original'

    mIOU, iou_class = evaluate(model, valloader, eval_mode, cfg, local_rank)

    if rank == 0:
        print('***** Evaluation {} ***** >>>> meanIOU: {:.3f}\n'.format(eval_mode, mIOU))
        # Report classes from worst to best IoU, keeping their original indices.
        class_names = CLASSES[cfg['dataset']]
        for cls_idx, iou in sorted(enumerate(iou_class), key=lambda item: item[1]):
            print('***** Evaluation ***** >>>> Class [{:} {:}] IoU: {:.2f}'.format(
                cls_idx, class_names[cls_idx], iou))
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()