upernet.py
'''
Function:
    Implementation of UPerNet
Author:
    Zhenchao Jin
'''
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base import BaseSegmentor
from ..pspnet import PyramidPoolingModule
from ....utils import SSSegOutputStructure
from ...backbones import BuildActivation, BuildNormalization
'''UPerNet'''
class UPerNet(BaseSegmentor):
    def __init__(self, cfg, mode):
        super(UPerNet, self).__init__(cfg, mode)
        align_corners, norm_cfg, act_cfg, head_cfg = self.align_corners, self.norm_cfg, self.act_cfg, cfg['head']
        # build feature2pyramid
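        # (Feature2Pyramid rescales the equal-stride outputs of a plain backbone, e.g. a ViT, into a multi-scale pyramid so the FPN-style head can consume them)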
        if 'feature2pyramid' in head_cfg:
            from ..base import Feature2Pyramid
            head_cfg['feature2pyramid']['norm_cfg'] = norm_cfg.copy()
            self.feats_to_pyramid_net = Feature2Pyramid(**head_cfg['feature2pyramid'])
        # build pyramid pooling module
        ppm_cfg = {
            'in_channels': head_cfg['in_channels_list'][-1], 'out_channels': head_cfg['feats_channels'], 'pool_scales': head_cfg['pool_scales'],
            'align_corners': align_corners, 'norm_cfg': copy.deepcopy(norm_cfg), 'act_cfg': copy.deepcopy(act_cfg),
        }
        self.ppm_net = PyramidPoolingModule(**ppm_cfg)
        # build lateral convs
        act_cfg_copy = copy.deepcopy(act_cfg)
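        # in-place activation is disabled for the lateral/FPN convs, presumably to avoid in-place modification of features that are reused during the top-down fusion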
        if 'inplace' in act_cfg_copy: act_cfg_copy['inplace'] = False
        self.lateral_convs = nn.ModuleList()
        for in_channels in head_cfg['in_channels_list'][:-1]:
            self.lateral_convs.append(nn.Sequential(
                nn.Conv2d(in_channels, head_cfg['feats_channels'], kernel_size=1, stride=1, padding=0, bias=False),
                BuildNormalization(placeholder=head_cfg['feats_channels'], norm_cfg=norm_cfg),
                BuildActivation(act_cfg_copy),
            ))
        # build fpn convs
        self.fpn_convs = nn.ModuleList()
        for in_channels in [head_cfg['feats_channels'],] * len(self.lateral_convs):
            self.fpn_convs.append(nn.Sequential(
                nn.Conv2d(in_channels, head_cfg['feats_channels'], kernel_size=3, stride=1, padding=1, bias=False),
                BuildNormalization(placeholder=head_cfg['feats_channels'], norm_cfg=norm_cfg),
                BuildActivation(act_cfg_copy),
            ))
        # build decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(head_cfg['feats_channels'] * len(head_cfg['in_channels_list']), head_cfg['feats_channels'], kernel_size=3, stride=1, padding=1, bias=False),
            BuildNormalization(placeholder=head_cfg['feats_channels'], norm_cfg=norm_cfg),
            BuildActivation(act_cfg),
            nn.Dropout2d(head_cfg['dropout']),
            nn.Conv2d(head_cfg['feats_channels'], cfg['num_classes'], kernel_size=1, stride=1, padding=0)
        )
        # build auxiliary decoder
        self.setauxiliarydecoder(cfg['auxiliary'])
        # freeze normalization layer if necessary
        if cfg.get('is_freeze_norm', False): self.freezenormalization()
    '''forward'''
    def forward(self, data_meta):
        img_size = data_meta.images.size(2), data_meta.images.size(3)
        # feed to backbone network
        backbone_outputs = self.transforminputs(self.backbone_net(data_meta.images), selected_indices=self.cfg['backbone'].get('selected_indices'))
        # feed to feats_to_pyramid_net
        if hasattr(self, 'feats_to_pyramid_net'): backbone_outputs = self.feats_to_pyramid_net(backbone_outputs)
        # feed to pyramid pooling module
        ppm_out = self.ppm_net(backbone_outputs[-1])
        # apply fpn
        inputs = backbone_outputs[:-1]
        lateral_outputs = [lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
        lateral_outputs.append(ppm_out)
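        # top-down pathway: upsample each deeper level and add it to the next shallower lateral feature (the PPM output acts as the top of the pyramid)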
        for i in range(len(lateral_outputs) - 1, 0, -1):
            prev_shape = lateral_outputs[i - 1].shape[2:]
            lateral_outputs[i - 1] = lateral_outputs[i - 1] + F.interpolate(lateral_outputs[i], size=prev_shape, mode='bilinear', align_corners=self.align_corners)
        fpn_outputs = [self.fpn_convs[i](lateral_outputs[i]) for i in range(len(lateral_outputs) - 1)]
        fpn_outputs.append(lateral_outputs[-1])
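        # resize every pyramid level to the spatial size of the shallowest level (fpn_outputs[0]) before channel-wise concatenation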
        fpn_outputs = [F.interpolate(out, size=fpn_outputs[0].size()[2:], mode='bilinear', align_corners=self.align_corners) for out in fpn_outputs]
        fpn_out = torch.cat(fpn_outputs, dim=1)
        # feed to decoder
        seg_logits = self.decoder(fpn_out)
        # forward according to the mode
        if self.mode in ['TRAIN', 'TRAIN_DEVELOP']:
            loss, losses_log_dict = self.customizepredsandlosses(
                seg_logits=seg_logits, annotations=data_meta.getannotations(), backbone_outputs=backbone_outputs, losses_cfg=self.cfg['losses'], img_size=img_size,
            )
            ssseg_outputs = SSSegOutputStructure(mode=self.mode, loss=loss, losses_log_dict=losses_log_dict) if self.mode == 'TRAIN' else SSSegOutputStructure(mode=self.mode, loss=loss, losses_log_dict=losses_log_dict, seg_logits=seg_logits)
        else:
            ssseg_outputs = SSSegOutputStructure(mode=self.mode, seg_logits=seg_logits)
        return ssseg_outputs
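

# A minimal sketch of the config keys this class reads (illustrative only; the values below are
# assumptions, and a full cfg also needs 'backbone', 'auxiliary', 'losses' and the normalization/
# activation settings consumed by BaseSegmentor):
# cfg = {
#     'num_classes': 150,
#     'head': {
#         'in_channels_list': [256, 512, 1024, 2048],  # per-stage backbone channels, deepest last
#         'feats_channels': 512,                        # channels after the lateral/FPN convs
#         'pool_scales': [1, 2, 3, 6],                  # PPM pooling grid sizes
#         'dropout': 0.1,
#     },
# }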