forked from HiroIshida/detic_ros
-
Notifications
You must be signed in to change notification settings - Fork 0
/
node_config.py
118 lines (99 loc) · 3.9 KB
/
node_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import sys
from dataclasses import dataclass
import rospkg
import rospy
import torch
# Dirty but no way, because CenterNet2 is not package oriented
sys.path.insert(0, os.path.join(sys.path[0], 'third_party/CenterNet2/'))
from centernet.config import add_centernet_config
from detectron2.config import get_cfg
from detic.config import add_detic_config
@dataclass
class NodeConfig:
    """Runtime configuration of the Detic ROS node.

    Build instances with :meth:`from_rosparam` (inside a running node) or
    :meth:`from_args` (programmatically), then turn them into a frozen
    detectron2 config with :meth:`to_detectron_config`.
    """

    enable_pubsub: bool
    out_debug_img: bool
    out_debug_segimg: bool
    verbose: bool
    use_jsk_msgs: bool
    vocabulary: str
    custom_vocabulary: str
    detic_config_path: str
    model_weights_path: str
    confidence_threshold: float
    device_name: str

    # Short model key -> Detic model file stem. The stem names both
    # ``detic_configs/<stem>.yaml`` and ``models/<stem>.pth`` in from_args.
    # Deliberately left unannotated: @dataclass only turns annotated class
    # attributes into fields, so this stays a plain shared class attribute.
    model_names = {
        'swin': 'Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size',
        'convnet': 'Detic_LCOCOI21k_CLIP_CXT21k_640b32_4x_ft4x_max-size',
        'res50': 'Detic_LCOCOI21k_CLIP_R5021k_640b32_4x_ft4x_max-size',
        'res18': 'Detic_LCOCOI21k_CLIP_R18_640b32_4x_ft4x_max-size',
    }

    @classmethod
    def from_args(
            cls,
            model_type: str = 'swin',
            enable_pubsub: bool = True,
            out_debug_img: bool = True,
            out_debug_segimg: bool = True,
            verbose: bool = False,
            use_jsk_msgs: bool = False,
            confidence_threshold: float = 0.5,
            device_name: str = 'auto',
            vocabulary: str = 'lvis',
            custom_vocabulary: str = ''):
        """Build a config whose model files live inside the ``detic_ros`` package.

        Args:
            model_type: Key of ``model_names`` selecting the Detic model.
            enable_pubsub, out_debug_img, out_debug_segimg, verbose,
                use_jsk_msgs, vocabulary, custom_vocabulary: Copied verbatim
                into the corresponding fields.
            confidence_threshold: Detection score threshold, stored as float.
            device_name: ``'cpu'``, ``'cuda'``, or ``'auto'``. ``'auto'``
                picks ``'cuda'`` when ``torch.cuda.is_available()``,
                otherwise ``'cpu'``.

        Returns:
            A new instance whose ``detic_config_path`` and
            ``model_weights_path`` point at
            ``<detic_ros>/detic_configs/<stem>.yaml`` and
            ``<detic_ros>/models/<stem>.pth``.

        Raises:
            ValueError: If ``device_name`` or ``model_type`` is not supported.
        """
        if device_name == 'auto':
            device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Explicit raises rather than ``assert``: asserts vanish under
        # ``python -O``, which would let bad values through to a bare
        # KeyError or a late failure deep inside detectron2.
        if device_name not in ('cpu', 'cuda'):
            raise ValueError(
                "unsupported device_name {!r}; expected 'auto', 'cpu' or "
                "'cuda'".format(device_name))
        if model_type not in cls.model_names:
            raise ValueError(
                'unknown model_type {!r}; expected one of {}'.format(
                    model_type, sorted(cls.model_names)))

        model_name = cls.model_names[model_type]
        pack_path = rospkg.RosPack().get_path('detic_ros')
        # Keyword arguments so a future reordering of the dataclass fields
        # cannot silently mis-assign values.
        return cls(
            enable_pubsub=enable_pubsub,
            out_debug_img=out_debug_img,
            out_debug_segimg=out_debug_segimg,
            verbose=verbose,
            use_jsk_msgs=use_jsk_msgs,
            vocabulary=vocabulary,
            custom_vocabulary=custom_vocabulary,
            detic_config_path=os.path.join(
                pack_path, 'detic_configs', model_name + '.yaml'),
            model_weights_path=os.path.join(
                pack_path, 'models', model_name + '.pth'),
            confidence_threshold=float(confidence_threshold),
            device_name=device_name)

    @classmethod
    def from_rosparam(cls):
        """Build a config from the node's private (``~``) ROS parameters.

        Unset parameters fall back to the defaults given below. The ROS
        parameter for the device is named ``~device``.
        """
        return cls.from_args(
            model_type=rospy.get_param('~model_type', 'swin'),
            enable_pubsub=rospy.get_param('~enable_pubsub', True),
            out_debug_img=rospy.get_param('~out_debug_img', True),
            out_debug_segimg=rospy.get_param('~out_debug_segimg', True),
            # NOTE(review): defaults to True here, while from_args defaults
            # verbose to False -- kept as-is, confirm this is intended.
            verbose=rospy.get_param('~verbose', True),
            use_jsk_msgs=rospy.get_param('~use_jsk_msgs', False),
            confidence_threshold=rospy.get_param('~confidence_threshold', 0.5),
            device_name=rospy.get_param('~device', 'auto'),
            vocabulary=rospy.get_param('~vocabulary', 'lvis'),
            custom_vocabulary=rospy.get_param('~custom_vocabulary', ''))

    def to_detectron_config(self):
        """Return a frozen detectron2 config for this node configuration.

        Merges the Detic YAML at ``detic_config_path``, applies the confidence
        threshold to the RetinaNet, ROI-heads and panoptic-FPN score
        thresholds, and points ``MODEL.WEIGHTS`` at ``model_weights_path``.
        """
        cfg = get_cfg()
        add_centernet_config(cfg)
        add_detic_config(cfg)
        cfg.merge_from_file(self.detic_config_path)
        # Set the device after merging the YAML so the node's device choice
        # cannot be silently overridden by the model config file.
        cfg.MODEL.DEVICE = self.device_name
        cfg.MODEL.RETINANET.SCORE_THRESH_TEST = self.confidence_threshold
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = self.confidence_threshold
        cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = self.confidence_threshold
        cfg.merge_from_list(['MODEL.WEIGHTS', self.model_weights_path])
        # Similar to https://github.com/facebookresearch/Detic/demo.py
        cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'rand'  # load later
        cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = True
        # Maybe should edit detic_configs/Base-C2_L_R5021k_640b64_4x.yaml
        pack_path = rospkg.RosPack().get_path('detic_ros')
        cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = os.path.join(
            pack_path, 'datasets/metadata/lvis_v1_train_cat_info.json')
        cfg.freeze()
        return cfg